diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py
index daedaadb1a7..98645a0cd4c 100644
--- a/benchmarks/kernels/benchmark_paged_attention.py
+++ b/benchmarks/kernels/benchmark_paged_attention.py
@@ -118,6 +118,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                     block_size,
                     max_seq_len,
                     alibi_slopes,
+                    None,  # TODO add custom bias
                     kv_cache_dtype,
                     k_scale,
                     v_scale,
@@ -138,6 +139,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
                     block_size,
                     max_seq_len,
                     alibi_slopes,
+                    None,
                     kv_cache_dtype,
                     k_scale,
                     v_scale,
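
A note on the `None` placeholders above: if the TODO is picked up, the benchmark could build a bias the same way the unit test in this PR does. A minimal sketch, assuming the kernel contract introduced here (float32, block-aligned last dimension); all sizes below are illustrative placeholders, not the benchmark's real values:

```python
# Hypothetical construction of a custom attention bias for the benchmark,
# mirroring tests/kernels/test_attention.py in this PR.
import random

import torch

num_seqs, num_query_heads, block_size, max_seq_len = 4, 8, 16, 1024
seq_lens = [random.randint(1, max_seq_len) for _ in range(num_seqs)]

# The kernel expects float32, padded to the block-aligned max_seq_len.
padded_max_seq_len = (max_seq_len + block_size - 1) // block_size * block_size
attn_bias = torch.zeros(num_seqs, num_query_heads, padded_max_seq_len,
                        dtype=torch.float)
for i, seq_len in enumerate(seq_lens):
    # Only the first seq_len positions of each row are read by the kernel.
    attn_bias[i, :, :seq_len] = torch.randn(num_query_heads, seq_len)
```
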
diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh
index eb216dc8baf..d19771a19cc 100644
--- a/csrc/attention/attention_kernels.cuh
+++ b/csrc/attention/attention_kernels.cuh
@@ -104,6 +104,8 @@ __device__ void paged_attention_kernel(
     const int* __restrict__ seq_lens,      // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
+    const float* __restrict__ attn_bias,  // [num_seqs, num_heads, padded_max_seq_len]
+    const int padded_max_seq_len,         // Avoid recomputing from seq_lens.
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
     const float* k_scale, const float* v_scale, const int tp_rank,
     const int blocksparse_local_blocks, const int blocksparse_vert_stride,
@@ -154,6 +156,14 @@ __device__ void paged_attention_kernel(
   const float alibi_slope =
       alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
 
+  // NOTE (NickLucche) Pointer to the `padded_max_seq_len` bias values for the
+  // current sequence and head.
+  const float* attn_bias_vec =
+      attn_bias == nullptr
+          ? nullptr
+          : attn_bias + seq_idx * num_heads * padded_max_seq_len +
+                head_idx * padded_max_seq_len;
+
   // A vector type to store a part of a key or a query.
   // The vector size is configured in such a way that the threads in a thread
   // group fetch or compute 16 bytes at a time. For example, if the size of a
@@ -293,8 +303,10 @@ __device__ void paged_attention_kernel(
       // This includes a reduction across the threads in the same thread group.
       float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
                              q_vecs[thread_group_offset], k_vecs);
-      // Add the ALiBi bias if slopes are given.
+      // Add the ALiBi bias if slopes are given, then add the custom bias if
+      // provided. TODO: should ALiBi and the custom bias be mutually exclusive?
       qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
+      qk += (attn_bias_vec != nullptr) ? attn_bias_vec[token_idx] : 0;
 
       if (thread_group_offset == 0) {
         // Store the partial reductions to shared memory.
@@ -512,6 +524,8 @@ __global__ void paged_attention_v1_kernel(
     const int* __restrict__ seq_lens,      // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
+    const float* __restrict__ attn_bias,
+    const int padded_max_seq_len,  // Avoid recomputing from seq_lens.
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
     const float* k_scale, const float* v_scale, const int tp_rank,
     const int blocksparse_local_blocks, const int blocksparse_vert_stride,
@@ -520,9 +534,9 @@ __global__ void paged_attention_v1_kernel(
                          KV_DTYPE, IS_BLOCK_SPARSE>(
       /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
       v_cache, num_kv_heads, scale, block_tables, seq_lens,
-      max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
-      kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
-      blocksparse_vert_stride, blocksparse_block_size,
+      max_num_blocks_per_seq, alibi_slopes, attn_bias, padded_max_seq_len,
+      q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
+      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
       blocksparse_head_sliding_step);
 }
 
@@ -548,6 +562,8 @@ __global__ void paged_attention_v2_kernel(
     const int* __restrict__ seq_lens,      // [num_seqs]
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
+    const float* __restrict__ attn_bias,
+    const int padded_max_seq_len,  // Avoid recomputing from seq_lens.
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
     const float* k_scale, const float* v_scale, const int tp_rank,
     const int blocksparse_local_blocks, const int blocksparse_vert_stride,
@@ -555,10 +571,10 @@ __global__ void paged_attention_v2_kernel(
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
                          KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
       exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
-      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
-      kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
-      blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
-      blocksparse_head_sliding_step);
+      block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, attn_bias,
+      padded_max_seq_len, q_stride, kv_block_stride, kv_head_stride, k_scale,
+      v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,
+      blocksparse_block_size, blocksparse_head_sliding_step);
 }
 
 // Grid: (num_heads, num_seqs).
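
For reference, the pointer arithmetic used for `attn_bias_vec` above amounts to plain contiguous indexing into a `[num_seqs, num_heads, padded_max_seq_len]` float32 tensor. A minimal sketch of the equivalent lookup on the Python side (the helper name is illustrative):

```python
import torch


def bias_at(attn_bias: torch.Tensor, seq_idx: int, head_idx: int,
            token_idx: int) -> float:
    """Mirror of: attn_bias + seq_idx * num_heads * padded_max_seq_len
    + head_idx * padded_max_seq_len, then indexed by token_idx."""
    num_seqs, num_heads, padded_max_seq_len = attn_bias.shape
    offset = (seq_idx * num_heads + head_idx) * padded_max_seq_len + token_idx
    return attn_bias.reshape(-1)[offset].item()


bias = torch.randn(2, 4, 32, dtype=torch.float32)
assert bias_at(bias, 1, 3, 7) == bias[1, 3, 7].item()
```
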
diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu
index 9b3a5c4b101..c32e852204d 100644
--- a/csrc/attention/paged_attention_v1.cu
+++ b/csrc/attention/paged_attention_v1.cu
@@ -29,21 +29,21 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
 
-#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                \
-  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                     \
-      ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE,        \
-                                              BLOCK_SIZE, NUM_THREADS,      \
-                                              KV_DTYPE, IS_BLOCK_SPARSE>),  \
-      shared_mem_size);                                                     \
-  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,        \
-                                  NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE>   \
-      <<<grid, block, shared_mem_size, stream>>>(                           \
-          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
-          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,    \
-          alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride,      \
-          k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks,      \
-          blocksparse_vert_stride, blocksparse_block_size,                  \
-          blocksparse_head_sliding_step);
+#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                  \
+  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                       \
+      ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE,          \
+                                              BLOCK_SIZE, NUM_THREADS,        \
+                                              KV_DTYPE, IS_BLOCK_SPARSE>),    \
+      shared_mem_size);                                                       \
+  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE,          \
+                                  NUM_THREADS, KV_DTYPE, IS_BLOCK_SPARSE>     \
+      <<<grid, block, shared_mem_size, stream>>>(                             \
+          out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads,   \
+          scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq,      \
+          alibi_slopes_ptr, attn_bias_ptr, padded_max_seq_len, q_stride,      \
+          kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
+          blocksparse_local_blocks, blocksparse_vert_stride,                  \
+          blocksparse_block_size, blocksparse_head_sliding_step);
 
 // TODO(woosuk): Tune NUM_THREADS.
 template <typename T, typename CACHE_T, int BLOCK_SIZE,
@@ -53,7 +53,8 @@ void paged_attention_v1_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
+    const std::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& attn_bias, torch::Tensor& k_scale,
     torch::Tensor& v_scale, const int tp_rank,
     const int blocksparse_local_blocks, const int blocksparse_vert_stride,
     const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
@@ -73,7 +74,21 @@ void paged_attention_v1_launcher(
       alibi_slopes
           ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
           : nullptr;
-
+  const float* attn_bias_ptr =
+      attn_bias ? reinterpret_cast<const float*>(attn_bias.value().data_ptr())
+                : nullptr;
+  const int padded_max_seq_len =
+      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
+  if (attn_bias_ptr) {
+    const torch::Tensor& abias = attn_bias.value();
+    TORCH_CHECK(abias.dtype() == torch::kFloat32,
+                "Unsupported bias dtype: ", abias.dtype());
+    TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len,
+                "The last dimension of the attention bias must "
+                "match the block-aligned maximum sequence length (",
+                padded_max_seq_len,
+                "). However, the given dimensions are: ", abias.sizes());
+  }
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
   CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
@@ -84,13 +99,11 @@ void paged_attention_v1_launcher(
   const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_seq_len =
-      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
-  int logits_size = padded_max_seq_len * sizeof(float);
-  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  const int logits_size = padded_max_seq_len * sizeof(float);
+  const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
   // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
   // Keep that in sync with the logic here!
-  int shared_mem_size = std::max(logits_size, outputs_size);
+  const int shared_mem_size = std::max(logits_size, outputs_size);
 
   dim3 grid(num_heads, num_seqs, 1);
   dim3 block(NUM_THREADS);
@@ -137,8 +150,8 @@ void paged_attention_v1_launcher(
   paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE,              \
                               IS_BLOCK_SPARSE>(                              \
       out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
-      seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank,        \
-      blocksparse_local_blocks, blocksparse_vert_stride,                     \
+      seq_lens, max_seq_len, alibi_slopes, attn_bias, k_scale, v_scale,      \
+      tp_rank, blocksparse_local_blocks, blocksparse_vert_stride,            \
       blocksparse_block_size, blocksparse_head_sliding_step);
 
 #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
@@ -179,6 +192,7 @@ void paged_attention_v1(
     torch::Tensor& seq_lens,      // [num_seqs]
     int64_t block_size, int64_t max_seq_len,
     const std::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& attn_bias,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
     torch::Tensor& v_scale, const int64_t tp_rank,
     const int64_t blocksparse_local_blocks,
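
The launcher above rejects a bias that is not float32 or whose last dimension is not the block-aligned maximum sequence length. A minimal caller-side helper that satisfies those checks (the helper name and shapes are illustrative, not part of this PR):

```python
from typing import List

import torch


def pad_attn_bias(per_seq_bias: List[torch.Tensor], max_seq_len: int,
                  block_size: int) -> torch.Tensor:
    """Pad per-sequence [num_heads, seq_len] biases into a single
    [num_seqs, num_heads, padded_max_seq_len] float32 tensor."""
    # Same rounding as DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE.
    padded = (max_seq_len + block_size - 1) // block_size * block_size
    num_heads = per_seq_bias[0].shape[0]
    out = torch.zeros(len(per_seq_bias), num_heads, padded,
                      dtype=torch.float32)
    for i, bias in enumerate(per_seq_bias):
        out[i, :, :bias.shape[-1]] = bias  # padded tail stays zero
    return out


biases = [torch.randn(8, 13), torch.randn(8, 40)]
padded = pad_attn_bias(biases, max_seq_len=40, block_size=16)
assert padded.shape == (2, 8, 48) and padded.dtype == torch.float32
```
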
diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu
index 9935359e02f..ccfd6cd60f5 100644
--- a/csrc/attention/paged_attention_v2.cu
+++ b/csrc/attention/paged_attention_v2.cu
@@ -36,8 +36,9 @@
       <<<grid, block, shared_mem_size, stream>>>(                              \
           exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
           value_cache_ptr, num_kv_heads, scale, block_tables_ptr,              \
-          seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride,    \
-          kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank,  \
+          seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr,              \
+          attn_bias_ptr, padded_max_seq_len, q_stride, kv_block_stride,        \
+          kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank,                   \
           blocksparse_local_blocks, blocksparse_vert_stride,                   \
           blocksparse_block_size, blocksparse_head_sliding_step);              \
   vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS,            \
@@ -54,7 +55,8 @@ void paged_attention_v2_launcher(
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
+    const std::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& attn_bias, torch::Tensor& k_scale,
     torch::Tensor& v_scale, const int tp_rank,
     const int blocksparse_local_blocks, const int blocksparse_vert_stride,
     const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
@@ -74,7 +76,21 @@ void paged_attention_v2_launcher(
       alibi_slopes
           ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
           : nullptr;
-
+  const float* attn_bias_ptr =
+      attn_bias ? reinterpret_cast<const float*>(attn_bias.value().data_ptr())
+                : nullptr;
+  const int padded_max_seq_len =
+      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
+  if (attn_bias_ptr) {
+    const torch::Tensor& abias = attn_bias.value();
+    TORCH_CHECK(abias.dtype() == torch::kFloat32,
+                "Unsupported bias dtype: ", abias.dtype());
+    TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len,
+                "The last dimension of the attention bias must "
+                "match the block-aligned maximum sequence length (",
+                padded_max_seq_len,
+                "). However, the given dimensions are: ", abias.sizes());
+  }
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
   float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
@@ -88,16 +104,16 @@ void paged_attention_v2_launcher(
   const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
-  int logits_size = PARTITION_SIZE * sizeof(float);
-  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
+  const int logits_size = PARTITION_SIZE * sizeof(float);
+  const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
 
   // For paged attention v2 kernel.
   dim3 grid(num_heads, num_seqs, max_num_partitions);
-  int shared_mem_size = std::max(logits_size, outputs_size);
+  const int shared_mem_size = std::max(logits_size, outputs_size);
   // For paged attention v2 reduce kernel.
   dim3 reduce_grid(num_heads, num_seqs);
-  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
+  const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
 
   dim3 block(NUM_THREADS);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
@@ -144,7 +160,7 @@ void paged_attention_v2_launcher(
                               IS_BLOCK_SPARSE>(                               \
       out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache,      \
       num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
-      k_scale, v_scale, tp_rank, blocksparse_local_blocks,                    \
+      attn_bias, k_scale, v_scale, tp_rank, blocksparse_local_blocks,         \
       blocksparse_vert_stride, blocksparse_block_size,                        \
       blocksparse_head_sliding_step);
 
@@ -190,6 +206,7 @@ void paged_attention_v2(
     torch::Tensor& seq_lens,      // [num_seqs]
     int64_t block_size, int64_t max_seq_len,
     const std::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& attn_bias,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
     torch::Tensor& v_scale, const int64_t tp_rank,
     const int64_t blocksparse_local_blocks,
diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp
index b9764056e8a..4c67a775a0b 100644
--- a/csrc/cpu/attention.cpp
+++ b/csrc/cpu/attention.cpp
@@ -459,7 +459,8 @@ void paged_attention_v1(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
-    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
+    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& attn_bias,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
     torch::Tensor& v_scale, const int64_t tp_rank,
     const int64_t blocksparse_local_blocks,
@@ -467,6 +468,8 @@ void paged_attention_v1(
     const int64_t blocksparse_head_sliding_step) {
   TORCH_CHECK(blocksparse_vert_stride <= 1,
               "CPU backend does not support blocksparse attention yet.");
+  TORCH_CHECK(!attn_bias.has_value(),
+              "CPU backend does not support custom attention bias.");
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
                                [&] {
                                  CPU_KERNEL_GUARD_IN(paged_attention_v1_impl)
@@ -782,6 +785,7 @@ void paged_attention_v2(
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
     int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& attn_bias,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
     torch::Tensor& v_scale, const int64_t tp_rank,
     const int64_t blocksparse_local_blocks,
@@ -789,6 +793,8 @@ void paged_attention_v2(
     const int64_t blocksparse_head_sliding_step) {
   TORCH_CHECK(blocksparse_vert_stride <= 1,
               "CPU backend does not support blocksparse attention yet.");
+  TORCH_CHECK(!attn_bias.has_value(),
+              "CPU backend does not support custom attention bias.");
   VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
                                [&] {
                                  CPU_KERNEL_GUARD_IN(paged_attention_v2_impl)
diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp
index 5d1c5f4c83d..9ddb837d182 100644
--- a/csrc/cpu/torch_bindings.cpp
+++ b/csrc/cpu/torch_bindings.cpp
@@ -24,12 +24,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Attention ops
   // Compute the attention between an input query and the cached keys/values
   // using PagedAttention.
+  // TODO: support attn_bias on the CPU backend.
   ops.def(
       "paged_attention_v1("
       "    Tensor! out, Tensor query, Tensor key_cache,"
       "    Tensor value_cache, int num_kv_heads, float scale,"
       "    Tensor block_tables, Tensor seq_lens, int block_size,"
-      "    int max_seq_len, Tensor? alibi_slopes,"
+      "    int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias,"
       "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
       "    int tp_rank, int blocksparse_local_blocks,"
       "    int blocksparse_vert_stride, int blocksparse_block_size,"
@@ -43,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "    Tensor! tmp_out, Tensor query, Tensor key_cache,"
       "    Tensor value_cache, int num_kv_heads, float scale,"
       "    Tensor block_tables, Tensor seq_lens, int block_size,"
-      "    int max_seq_len, Tensor? alibi_slopes,"
+      "    int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias,"
       "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
       "    int tp_rank, int blocksparse_local_blocks,"
       "    int blocksparse_vert_stride, int blocksparse_block_size,"
diff --git a/csrc/ops.h b/csrc/ops.h
index e39d4ef3188..6a467fd6f37 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -33,7 +33,8 @@ void paged_attention_v1(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
-    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
+    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& attn_bias,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
     torch::Tensor& v_scale, const int64_t tp_rank,
     const int64_t blocksparse_local_blocks,
@@ -45,7 +46,8 @@ void paged_attention_v2(
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
-    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
+    int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
+    const std::optional<torch::Tensor>& attn_bias,
     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
     torch::Tensor& v_scale, const int64_t tp_rank,
     const int64_t blocksparse_local_blocks,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index c03806f430a..29e91762c74 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -29,7 +29,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "    Tensor! out, Tensor query, Tensor key_cache,"
       "    Tensor value_cache, int num_kv_heads, float scale,"
       "    Tensor block_tables, Tensor seq_lens, int block_size,"
-      "    int max_seq_len, Tensor? alibi_slopes,"
+      "    int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias,"
       "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
       "    int tp_rank, int blocksparse_local_blocks,"
       "    int blocksparse_vert_stride, int blocksparse_block_size,"
@@ -43,7 +43,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "    Tensor! tmp_out, Tensor query, Tensor key_cache,"
       "    Tensor value_cache, int num_kv_heads, float scale,"
       "    Tensor block_tables, Tensor seq_lens, int block_size,"
-      "    int max_seq_len, Tensor? alibi_slopes,"
+      "    int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias,"
       "    str kv_cache_dtype, Tensor k_scale, Tensor v_scale,"
       "    int tp_rank, int blocksparse_local_blocks,"
       "    int blocksparse_vert_stride, int blocksparse_block_size,"
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index b667d8d9e03..314db5db121 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -39,6 +39,7 @@
 
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
+USE_CUSTOM_ATTN_BIAS = [False, True]
 KV_CACHE_DTYPE = ["auto", "fp8"]
 SEEDS = [0]
 CUDA_DEVICES = [
@@ -62,16 +63,11 @@ def ref_masked_attention(
 
 
 def ref_single_query_cached_kv_attention(
-    output: torch.Tensor,
-    query: torch.Tensor,
-    num_queries_per_kv: int,
-    key_cache: torch.Tensor,
-    value_cache: torch.Tensor,
-    block_tables: torch.Tensor,
-    seq_lens: torch.Tensor,
-    scale: float,
-    alibi_slopes: Optional[torch.Tensor],
-) -> None:
+        output: torch.Tensor, query: torch.Tensor, num_queries_per_kv: int,
+        key_cache: torch.Tensor, value_cache: torch.Tensor,
+        block_tables: torch.Tensor, seq_lens: torch.Tensor, scale: float,
+        alibi_slopes: Optional[torch.Tensor],
+        attn_bias: Optional[List[torch.Tensor]]) -> None:
     num_query_heads = query.shape[1]
     num_kv_heads = value_cache.shape[1]
     head_size = value_cache.shape[2]
@@ -104,15 +100,17 @@ def ref_single_query_cached_kv_attention(
             keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
             values = torch.repeat_interleave(values, num_queries_per_kv, dim=1)
 
-        alibi_bias = None
+        bias = None
         if alibi_slopes is not None:
             # Create the ALiBi bias used in the paged attention kernel.
             position_ids = torch.arange(seq_len).int()
             alibi_bias = (position_ids - seq_len + 1).float()
             alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
                 1, 1, -1)
-
-        out = ref_masked_attention(q, keys, values, scale, alibi_bias)
+            bias = alibi_bias
+        if attn_bias is not None:
+            bias = attn_bias[i] if bias is None else bias + attn_bias[i]
+        out = ref_masked_attention(q, keys, values, scale, bias)
         out = out.view(num_query_heads, head_size)
         output[i].copy_(out, non_blocking=True)
 
@@ -124,6 +122,7 @@ def ref_single_query_cached_kv_attention(
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("use_alibi", USE_ALIBI)
+@pytest.mark.parametrize("use_custom_attn_bias", USE_CUSTOM_ATTN_BIAS)
 @pytest.mark.parametrize("block_size", BLOCK_SIZES)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@@ -136,6 +135,7 @@ def test_paged_attention(
     num_heads: Tuple[int, int],
     head_size: int,
     use_alibi: bool,
+    use_custom_attn_bias: bool,
     block_size: int,
     dtype: torch.dtype,
     kv_cache_dtype: str,
@@ -155,7 +155,7 @@ def test_paged_attention(
 
     assert num_query_heads % num_kv_heads == 0
     num_queries_per_kv = num_query_heads // num_kv_heads
-    alibi_slopes = None
+    alibi_slopes, attn_bias = None, None
     if use_alibi:
         alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
 
@@ -163,9 +163,30 @@ def test_paged_attention(
     seq_lens[-1] = MAX_SEQ_LEN
     max_seq_len = max(seq_lens)
     seq_lens = torch.tensor(seq_lens, dtype=torch.int)
+    attn_bias_list = None
+    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
+    if use_custom_attn_bias:
+        # NOTE (NickLucche) each sequence can have a different bias,
+        # depending on its length, but it *must* be padded to the
+        # block-aligned max_seq_len and be of type float32!
+        attn_bias_list = [
+            torch.randn(num_query_heads, 1, seq_len, dtype=torch.float)
+            for seq_len in seq_lens
+        ]
+        block_aligned_max_seq_len = max_num_blocks_per_seq * block_size
+        attn_bias = torch.empty(
+            num_seqs,
+            num_query_heads,
+            1,
+            block_aligned_max_seq_len,  # padded dim
+            device=device,
+            dtype=torch.float)
+
+        for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)):
+            # Fill the first `seq_len` bias entries for each head/sequence.
+            attn_bias[i, :, :, :seq_len] = bias
 
     # Create the block tables.
-    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     block_tables_lst: List[List[int]] = []
     for _ in range(num_seqs):
         block_table = [
@@ -201,6 +222,7 @@ def test_paged_attention(
             block_size,
             max_seq_len,
             alibi_slopes,
+            attn_bias,
             kv_cache_dtype,
             k_scale,
             v_scale,
@@ -209,7 +231,7 @@ def test_paged_attention(
         opcheck(torch.ops._C.paged_attention_v1,
                 (output, query, key_cache, value_cache, num_kv_heads, scale,
                  block_tables, seq_lens, block_size, max_seq_len, alibi_slopes,
-                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+                 attn_bias, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
                 cond=(head_size == HEAD_SIZES[0]
                       and block_size == BLOCK_SIZES[0]))
 
@@ -242,18 +264,20 @@ def test_paged_attention(
                 block_size,
                 max_seq_len,
                 alibi_slopes,
+                attn_bias,
                 kv_cache_dtype,
                 k_scale,
                 v_scale,
             )
 
-            opcheck(torch.ops._C.paged_attention_v2,
-                    (output, exp_sums, max_logits, tmp_output, query,
-                     key_cache, value_cache, num_kv_heads, scale, block_tables,
-                     seq_lens, block_size, max_seq_len, alibi_slopes,
-                     kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
-                    cond=(head_size == HEAD_SIZES[0]
-                          and block_size == BLOCK_SIZES[0]))
+            opcheck(
+                torch.ops._C.paged_attention_v2,
+                (output, exp_sums, max_logits, tmp_output, query, key_cache,
+                 value_cache, num_kv_heads, scale, block_tables, seq_lens,
+                 block_size, max_seq_len, alibi_slopes, attn_bias,
+                 kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0),
+                cond=(head_size == HEAD_SIZES[0]
+                      and block_size == BLOCK_SIZES[0]))
 
         else:
             ops.paged_attention_rocm(
@@ -307,17 +331,10 @@ def test_paged_attention(
         value_cache = dequantized_value_cache
 
     ref_output = torch.empty_like(query)
-    ref_single_query_cached_kv_attention(
-        ref_output,
-        query,
-        num_queries_per_kv,
-        key_cache,
-        value_cache,
-        block_tables,
-        seq_lens,
-        scale,
-        alibi_slopes,
-    )
+    ref_single_query_cached_kv_attention(ref_output, query, num_queries_per_kv,
+                                         key_cache, value_cache, block_tables,
+                                         seq_lens, scale, alibi_slopes,
+                                         attn_bias_list)
 
     # NOTE(woosuk): Due to the kernel-level differences in the two
     # implementations, there is a small numerical difference in the two
diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py
index e653d34d00e..bedff278f95 100644
--- a/tests/kernels/test_blocksparse_attention.py
+++ b/tests/kernels/test_blocksparse_attention.py
@@ -230,6 +230,7 @@ def test_paged_attention(
             block_size,
             max_seq_len,
             alibi_slopes,
+            None,  # TODO add custom bias
             kv_cache_dtype,
             k_scale,
             v_scale,
@@ -267,6 +268,7 @@ def test_paged_attention(
             block_size,
             max_seq_len,
             alibi_slopes,
+            None,
             kv_cache_dtype,
             k_scale,
             v_scale,
diff --git a/tests/models/encoder_decoder/language/conftest.py b/tests/models/encoder_decoder/language/conftest.py
new file mode 100644
index 00000000000..c318dcb783c
--- /dev/null
+++ b/tests/models/encoder_decoder/language/conftest.py
@@ -0,0 +1,148 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, Dict, List, Optional, Type
+
+from transformers import AutoModelForSeq2SeqLM
+
+from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt,
+                          HfRunner, VllmRunner)
+from ...utils import check_logprobs_close
+from .utils import vllm_to_hf_output
+
+
+def compare_hf_vllm_logprobs(
+        hf_runner: Type[HfRunner],
+        vllm_runner: Type[VllmRunner],
+        prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
+        decoder_prompt_type: DecoderPromptType,
+        model: str,
+        *,
+        dtype: str,
+        max_tokens: int,
+        num_logprobs: int,
+        tensor_parallel_size: int,
+        distributed_executor_backend: Optional[str] = None,
+        vllm_runner_kwargs: Optional[Dict[str, Any]] = None,
+        hf_tokens_to_skip: int = 0) -> None:
+    '''
+    Test the provided model on a variety of encoder/decoder input prompts
+    by validating it against the corresponding HuggingFace (HF) model.
+
+    Arguments:
+
+    * hf_runner: HuggingFace (HF) test model runner
+    * vllm_runner: vLLM test model runner
+    * prompts: list of explicit encoder/decoder prompts under test
+    * decoder_prompt_type: key into the example_encoder_decoder_prompts
+                           dictionary; selects specific encoder/decoder
+                           prompt scenarios to test
+    * model: the HF ID of the model under test
+    * dtype: the tensor datatype to employ
+    * max_tokens
+    * num_logprobs
+    * hf_tokens_to_skip: number of HF decoded tokens to skip (see below)
+
+    A note on using HF BART as a baseline for validating vLLM BART,
+    specifically when the decoder prompt is None. 
+    
+    The HF GenerationMixin's default behavior is to force the first
+    decoded token to be <BOS> if the prompt does not already contain
+    <BOS> (this is accomplished using a logit
+    processor setting.)
+    
+    So when we use HF BART as our baseline for comparison, note that
+    when the user provides a request with a None decoder prompt
+    (i.e. a singleton encoder prompt, or else an explicit encoder/
+    decoder prompt with the decoder sub-prompt set to None), HF and
+    vLLM handle this in different ways:
+    
+    * HF will (1) tokenize the None prompt as an empty token-list, 
+      (2) append <decoder-start-token> to the beginning, yielding
+      [<decoder-start-token>], (3) pass this token list to the model, and
+      then (4) after computing logits during prefill, override the model
+      logits & force <BOS> to be the first generated token.
+    
+    * vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder-
+      start-token to the beginning, yielding [<decoder-start-token><BOS>],
+      (3) pass these tokens to the model & proceed with generation.
+    
+    The net effect is that compared to vLLM, the list of HF *decoded* tokens
+    will contain one more initial <BOS> than the vLLM generated tokens,
+    because vLLM's <BOS> token is injected into the prompt rather than into
+    the generated output. This is in spite of the fact that overall, the
+    complete sequences (prompt + decoded tokens) produced by vLLM will match
+    HF.
+    
+    So when we use HF decoded token output to validate vLLM's decoded token
+    output, the testing process must account for the difference in decoded
+    token sequences between vLLM and HF specifically in the
+    decoder-prompt-is-None case. 
+    
+    One option is to disable the logit processor feature that forces the
+    <BOS> token to be decoded (forced_bos_token_id = None), eliminating
+    the problem entirely. However this is not "normal" BART usage.
+    
+    The other option is - only in the decoder-prompt-is-None case - to
+    discard the first decoded token from the HF output before comparing it
+    to vLLM.
+
+    To that end, `hf_tokens_to_skip` must be set to the number of HF decoded 
+    tokens to skip during the process of validating the vLLM decoded output.
+    '''
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default).
+
+    # Note: currently encoder/decoder models are only compatible with
+    # enforce_eager=True. Normally this is not a problem because
+    # for encoder/decoder models vLLM will
+    # default to enforce_eager=True if enforce_eager
+    # is left unspecified. However, the
+    # VllmRunner test fixture (which wraps around the LLM class) defaults to
+    # enforce_eager=False (a behavior which a number of already-existing
+    # decoder-only unit tests expect), so when testing an encoder/decoder
+    # model we must explicitly specify enforce_eager=True in the VllmRunner
+    # constructor.
+    if not vllm_runner_kwargs:
+        vllm_runner_kwargs = dict()
+    with vllm_runner(model,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True,
+                     **vllm_runner_kwargs) as vllm_model:
+        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
+            prompts, max_tokens, num_logprobs)
+
+    # Configuration settings for HF baseline
+    hf_kwargs = {
+        "top_k": None,
+        "num_beams": 1,
+        "repetition_penalty": 1.0,
+        "top_p": 1.0,
+        "length_penalty": 1.0,
+        "early_stopping": False,
+        "no_repeat_ngram_size": None,
+        "min_length": 0
+    }
+
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
+        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
+            prompts,
+            max_tokens,
+            num_logprobs,
+            **hf_kwargs,
+        ))
+
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=[
+            vllm_to_hf_output(vllm_output, decoder_prompt_type)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+        num_outputs_0_skip_tokens=hf_tokens_to_skip,
+    )
diff --git a/tests/models/encoder_decoder/language/language/__init__.py b/tests/models/encoder_decoder/language/language/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/encoder_decoder/language/test_bart.py
index 81b629fdcf1..69a9a2a99e9 100644
--- a/tests/models/encoder_decoder/language/test_bart.py
+++ b/tests/models/encoder_decoder/language/test_bart.py
@@ -3,170 +3,12 @@
 
 Run `pytest tests/models/encoder_decoder/language/test_bart.py`.
 """
-from typing import List, Optional, Tuple, Type
-
 import pytest
-from transformers import AutoModelForSeq2SeqLM
-
-from vllm.sequence import SampleLogprobs
-
-from ....conftest import (DecoderPromptType, ExplicitEncoderDecoderPrompt,
-                          HfRunner, VllmRunner)
-from ....utils import multi_gpu_test
-from ...utils import check_logprobs_close
-
-
-def vllm_to_hf_output(
-    vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
-    decoder_prompt_type: DecoderPromptType,
-):
-    """Sanitize vllm output to be comparable with hf output."""
-    output_ids, output_str, out_logprobs = vllm_output
-
-    hf_output_str = output_str + "</s>"
-    if decoder_prompt_type == DecoderPromptType.NONE:
-        hf_output_str = "<s>" + hf_output_str
-
-    return output_ids, hf_output_str, out_logprobs
-
-
-def run_test(
-    hf_runner: Type[HfRunner],
-    vllm_runner: Type[VllmRunner],
-    prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
-    decoder_prompt_type: DecoderPromptType,
-    model: str,
-    *,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-    tensor_parallel_size: int,
-    distributed_executor_backend: Optional[str] = None,
-) -> None:
-    '''
-    Test the vLLM BART model for a variety of encoder/decoder input prompts,
-    by validating it against HuggingFace (HF) BART.
-
-    Arguments:
 
-    * hf_runner: HuggingFace (HF) test model runner
-    * vllm_runner: vLLM test model runner
-    * example_encoder_decoder_prompts: test fixture which provides a 
-                                       dictionary of dummy prompts
-    * model: the HF ID of the specific BART variant under test
-    * dtype: the tensor datatype to employ
-    * max_tokens
-    * num_logprobs
-    * decoder_prompt_type: key into the example_encoder_decoder_prompts
-                           dictionary; selects specific encoder/decoder
-                           prompt scenarios to test
+from tests.utils import multi_gpu_test
 
-    A note on using HF BART as a baseline for validating vLLM BART,
-    specifically when the decoder prompt is None. 
-    
-    The HF GenerationMixin's default behavior is to force the first
-    decoded token to be <BOS> if the prompt does not already contain
-    <BOS> (this is accomplished using a logit
-    processor setting.)
-    
-    So when we use HF BART as our baseline for comparison, note that
-    when the user provides a request with a None decoder prompt
-    (i.e. a singleton encoder prompt, or else an explicit encoder/
-    decoder prompt with the decoder sub-prompt set to None), HF and
-    vLLM handle this in different ways:
-    
-    * HF will (1) tokenize the None prompt as an empty token-list, 
-      (2) append <decoder-start-token> to the beginning, yielding
-      [<decoder-start-token>], (3) pass this token list to the model, and
-      then (4) after computing logits during prefill, override the model
-      logits & force <BOS> to be the first generated token.
-    
-    * vLLM will (1) tokenize the None prompt as [<BOS>], (2) append decoder-
-      start-token to the beginning, yielding [<decoder-start-token><BOS>],
-      (3) pass these tokens to the model & proceed with generation.
-    
-    The net effect is that compared to vLLM, the list of HF *decoded* tokens
-    will contain one more initial <BOS> than the vLLM generated tokens,
-    because vLLM's <BOS> token is injected into the prompt rather than into
-    the generated output. This is in spite of the fact that overall, the
-    complete sequences (prompt + decoded tokens) produced by vLLM will match
-    HF.
-    
-    So when we use HF decoded token output to validate vLLM's decoded token
-    output, the testing process must account for the difference in decoded
-    token sequences between vLLM and HF specifically in the
-    decoder-prompt-is-None case. 
-    
-    One option is to disable the logit processor feature that forces the
-    <BOS> token to be decoded (forced_bos_token_id = None), eliminating
-    the problem entirely. However this is not "normal" BART usage.
-    
-    The other option is - only in the decoder-prompt-is-None case - to
-    discard the first decoded token from the HF output before comparing it
-    to vLLM.
-
-    To that end, when testing the scenario where the decoder prompt is None
-    (and only in that one scenario), this test skips the first HF decoded
-    token during the process of validating the vLLM decoded output.
-    '''
-
-    # NOTE: take care of the order. run vLLM first, and then run HF.
-    # vLLM needs a fresh new process without cuda initialization.
-    # if we run HF first, the cuda initialization will be done and it
-    # will hurt multiprocessing backend with fork method (the default).
-
-    # Note: currently encoder/decoder models are only compatible with
-    # enforce_eager=True. Normally this is not a problem because
-    # for encoder/decoder models vLLM will
-    # default to enforce_eager=True if enforce_eager
-    # is left unspecified. However, the
-    # VllmRunner test fixture (which wraps around the LLM class) defaults to
-    # enforce_eager=False (a behavior which a number of already-exisitng
-    # decoder-only unit tests expect), so when testing an encoder/decoder
-    # model we must explicitly specify enforce_eager=True in the VllmRunner
-    # constructor.
-    with vllm_runner(model,
-                     dtype=dtype,
-                     tensor_parallel_size=tensor_parallel_size,
-                     distributed_executor_backend=distributed_executor_backend,
-                     enforce_eager=True) as vllm_model:
-        vllm_outputs = vllm_model.generate_encoder_decoder_greedy_logprobs(
-            prompts, max_tokens, num_logprobs)
-
-    # Configuration settings for HF baseline
-    hf_kwargs = {
-        "top_k": None,
-        "num_beams": 1,
-        "repetition_penalty": 1.0,
-        "top_p": 1.0,
-        "length_penalty": 1.0,
-        "early_stopping": False,
-        "no_repeat_ngram_size": None,
-        "min_length": 0
-    }
-
-    with hf_runner(model, dtype=dtype,
-                   auto_cls=AutoModelForSeq2SeqLM) as hf_model:
-        hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
-            prompts,
-            max_tokens,
-            num_logprobs,
-            **hf_kwargs,
-        ))
-
-    hf_skip_tokens = (1
-                      if decoder_prompt_type == DecoderPromptType.NONE else 0)
-
-    check_logprobs_close(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=[
-            vllm_to_hf_output(vllm_output, decoder_prompt_type)
-            for vllm_output in vllm_outputs
-        ],
-        name_0="hf",
-        name_1="vllm",
-        num_outputs_0_skip_tokens=hf_skip_tokens,
-    )
+from ....conftest import DecoderPromptType
+from .conftest import compare_hf_vllm_logprobs
 
 
 @pytest.mark.parametrize(
@@ -184,7 +26,7 @@ def run_test(
 def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
                 dtype, max_tokens, num_logprobs, decoder_prompt_type) -> None:
 
-    run_test(
+    compare_hf_vllm_logprobs(
         hf_runner,
         vllm_runner,
         example_encoder_decoder_prompts[decoder_prompt_type],
@@ -194,7 +36,7 @@ def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
         max_tokens=max_tokens,
         num_logprobs=num_logprobs,
         tensor_parallel_size=1,
-    )
+        hf_tokens_to_skip=int(decoder_prompt_type == DecoderPromptType.NONE))
 
 
 @multi_gpu_test(num_gpus=2)
@@ -209,7 +51,7 @@ def test_models_distributed(hf_runner, vllm_runner,
                             distributed_executor_backend, model, dtype,
                             max_tokens, num_logprobs,
                             decoder_prompt_type) -> None:
-    run_test(
+    compare_hf_vllm_logprobs(
         hf_runner,
         vllm_runner,
         example_encoder_decoder_prompts[decoder_prompt_type],
diff --git a/tests/models/encoder_decoder/language/test_t5.py b/tests/models/encoder_decoder/language/test_t5.py
new file mode 100644
index 00000000000..48f61210a73
--- /dev/null
+++ b/tests/models/encoder_decoder/language/test_t5.py
@@ -0,0 +1,74 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Compare the outputs of HF and vLLM for T5 models using greedy sampling.
+Based on tests/models/encoder_decoder/language/test_bart.py.
+
+Run `pytest tests/models/encoder_decoder/language/test_t5.py`.
+"""
+import pytest
+
+from tests.utils import multi_gpu_test
+from vllm.attention.selector import (_Backend,
+                                     global_force_attn_backend_context_manager)
+
+from ....conftest import DecoderPromptType
+from .conftest import compare_hf_vllm_logprobs
+
+
+@pytest.mark.parametrize(
+    "model",
+    [
+        pytest.param("Finnish-NLP/ul2-tiny-nl6-finnish"),
+        # pytest.param("google/flan-t5-base"),
+    ],
+)
+@pytest.mark.parametrize("vllm_kwargs", [{"max_model_len": 512}])
+@pytest.mark.parametrize("dtype", ["float", "bfloat16"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+# TODO: custom prompts here generate high-entropy output, causing
+# differences in sampled tokens.
+@pytest.mark.parametrize("decoder_prompt_type",
+                         [DecoderPromptType.NONE, DecoderPromptType.EMPTY_STR])
+def test_models(hf_runner, vllm_runner, example_encoder_decoder_prompts, model,
+                dtype, max_tokens, num_logprobs, decoder_prompt_type,
+                vllm_kwargs) -> None:
+    # The model is only supported on the xformers backend for now.
+    with global_force_attn_backend_context_manager(_Backend.XFORMERS):
+        compare_hf_vllm_logprobs(
+            hf_runner,
+            vllm_runner,
+            example_encoder_decoder_prompts[decoder_prompt_type],
+            decoder_prompt_type,
+            model,
+            dtype=dtype,
+            max_tokens=max_tokens,
+            num_logprobs=num_logprobs,
+            tensor_parallel_size=1,
+            vllm_runner_kwargs=vllm_kwargs)
+
+
+@multi_gpu_test(num_gpus=2)
+@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
+@pytest.mark.parametrize("model", ["google/t5-small"])
+@pytest.mark.parametrize("dtype", ["float"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("decoder_prompt_type", [DecoderPromptType.NONE])
+def test_models_distributed(hf_runner, vllm_runner,
+                            example_encoder_decoder_prompts,
+                            distributed_executor_backend, model, dtype,
+                            max_tokens, num_logprobs,
+                            decoder_prompt_type) -> None:
+    with global_force_attn_backend_context_manager(_Backend.XFORMERS):
+        compare_hf_vllm_logprobs(
+            hf_runner,
+            vllm_runner,
+            example_encoder_decoder_prompts[decoder_prompt_type],
+            decoder_prompt_type,
+            model,
+            dtype=dtype,
+            max_tokens=max_tokens,
+            num_logprobs=num_logprobs,
+            tensor_parallel_size=2,
+            distributed_executor_backend=distributed_executor_backend,
+        )
diff --git a/tests/models/encoder_decoder/language/utils.py b/tests/models/encoder_decoder/language/utils.py
new file mode 100644
index 00000000000..ae91b160686
--- /dev/null
+++ b/tests/models/encoder_decoder/language/utils.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import List, Optional, Tuple
+
+from vllm.sequence import SampleLogprobs
+
+from ....conftest import DecoderPromptType
+
+
+def vllm_to_hf_output(
+    vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
+    decoder_prompt_type: DecoderPromptType,
+):
+    """Sanitize vllm output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str + "</s>"
+    if decoder_prompt_type == DecoderPromptType.NONE:
+        hf_output_str = "<s>" + hf_output_str
+
+    return output_ids, hf_output_str, out_logprobs
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index a6823501676..98d7c9f875c 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -49,6 +49,7 @@ def paged_attention_v1(
     block_size: int,
     max_seq_len: int,
     alibi_slopes: Optional[torch.Tensor],
+    attn_bias: Optional[torch.Tensor],
     kv_cache_dtype: str,
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
@@ -60,8 +61,8 @@ def paged_attention_v1(
 ) -> None:
     torch.ops._C.paged_attention_v1(
         out, query, key_cache, value_cache, num_kv_heads, scale, block_tables,
-        seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype,
-        k_scale, v_scale, tp_rank, blocksparse_local_blocks,
+        seq_lens, block_size, max_seq_len, alibi_slopes, attn_bias,
+        kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
         blocksparse_vert_stride, blocksparse_block_size,
         blocksparse_head_sliding_step)
 
@@ -81,6 +82,7 @@ def paged_attention_v2(
     block_size: int,
     max_seq_len: int,
     alibi_slopes: Optional[torch.Tensor],
+    attn_bias: Optional[torch.Tensor],
     kv_cache_dtype: str,
     k_scale: torch.Tensor,
     v_scale: torch.Tensor,
@@ -93,7 +95,7 @@ def paged_attention_v2(
     torch.ops._C.paged_attention_v2(
         out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache,
         num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len,
-        alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank,
+        alibi_slopes, attn_bias, kv_cache_dtype, k_scale, v_scale, tp_rank,
         blocksparse_local_blocks, blocksparse_vert_stride,
         blocksparse_block_size, blocksparse_head_sliding_step)
 
diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py
index 9765e7881ad..ef8009a3cc8 100644
--- a/vllm/attention/backends/blocksparse_attn.py
+++ b/vllm/attention/backends/blocksparse_attn.py
@@ -443,6 +443,7 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
+                None,  # TODO support attn_bias
                 layer._k_scale,
                 layer._v_scale,
                 tp_rank=self.tp_rank,
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 02bff57a62b..8f3c94920be 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -837,6 +837,7 @@ def forward(
                     self.num_kv_heads,
                     self.scale,
                     self.alibi_slopes,
+                    None,  # TODO support attn_bias
                     layer._k_scale,
                     layer._v_scale,
                 )
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index 723a4558d0b..196248682d7 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -159,8 +159,8 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
     def __post_init__(self):
         # Set during the execution of the first attention op.
         # It is a list because it is needed to set per prompt
-        # when alibi slopes is used. It is because of the limitation
-        # from xformer API.
+        # when alibi slopes or a custom attention bias is used,
+        # due to a limitation of the xformers API.
         # will not appear in the __repr__ and __init__
         self.attn_bias: Optional[List[AttentionBias]] = None
         self.encoder_attn_bias: Optional[List[AttentionBias]] = None
@@ -292,7 +292,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]:
 def _get_attn_bias(
     attn_metadata: XFormersMetadata,
     attn_type: str,
-) -> Optional[AttentionBias]:
+) -> Optional[List[AttentionBias]]:
     '''
     Extract appropriate attention bias from attention metadata
     according to attention type.
@@ -548,6 +548,7 @@ def forward(
 
         assert query.shape[0] == num_prefill_query_tokens
         assert decode_query.shape[0] == num_decode_query_tokens
+        attn_bias = _get_attn_bias(attn_metadata, attn_type)
 
         if prefill_meta := attn_metadata.prefill_metadata:
             # Prompt run.
@@ -555,6 +556,10 @@ def forward(
                 # normal attention.
                 # block tables are empty if the prompt does not have a cached
                 # prefix.
+                # As prefill metadata is cached on the first call, make sure
+                # attn_bias is up to date.
+                if attn_bias:
+                    _set_attn_bias(prefill_meta, attn_bias, attn_type)
                 out = self._run_memory_efficient_xformers_forward(
                     query, key, value, prefill_meta, attn_type=attn_type)
                 assert out.shape == output[:num_prefill_query_tokens].shape
@@ -583,6 +588,7 @@ def forward(
                     prefill_meta.context_lens_tensor,
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
+                    _get_attn_bias(attn_metadata, attn_type),
                     self.sliding_window,
                     layer._k_scale,
                     layer._v_scale,
@@ -600,6 +606,13 @@ def forward(
                 block_tables_arg,
             ) = get_seq_len_block_table_args(decode_meta, False, attn_type)
 
+            if attn_bias:
+                assert len(attn_bias) == 1, (
+                    "PagedAttention expects a single bias to be provided "
+                    "for all input sequences.")
+
+                # All decoding sequences share this single (padded) bias.
+                attn_bias = attn_bias[0]
             output[num_prefill_query_tokens:] = PagedAttention.forward_decode(
                 decode_query,
                 key_cache,
@@ -611,6 +624,7 @@ def forward(
                 self.num_kv_heads,
                 self.scale,
                 self.alibi_slopes,
+                attn_bias,
                 layer._k_scale,
                 layer._v_scale,
             )
@@ -706,6 +720,7 @@ def _run_memory_efficient_xformers_forward(
                 else:
                     raise ValueError("Unknown AttentionType: %s", attn_type)
 
+                assert isinstance(attn_bias, BlockDiagonalMask)
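+                # (make_local_attention below is defined on
+                # BlockDiagonalMask, hence the narrowing assert above.)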
                 if self.sliding_window is not None:
                     attn_bias = attn_bias.make_local_attention(
                         self.sliding_window)
@@ -719,10 +734,10 @@ def _run_memory_efficient_xformers_forward(
 
             _set_attn_bias(attn_metadata, attn_bias, attn_type)
 
-        # No alibi slopes.
+        # No alibi slopes and no multi-sequence custom attention bias.
         # TODO(woosuk): Too many view operations. Let's try to reduce
         # them in the future for code readability.
-        if self.alibi_slopes is None:
+        if self.alibi_slopes is None and len(attn_bias) == 1:
             # Add the batch dimension.
             query = query.unsqueeze(0)
             key = key.unsqueeze(0)
@@ -736,14 +751,16 @@ def _run_memory_efficient_xformers_forward(
                 scale=self.scale)
             return out.view_as(original_query)
 
-        # Attention with alibi slopes.
+        # Attention with alibi slopes or per-sequence custom attention biases.
         # FIXME(woosuk): Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.
-        assert attn_metadata.seq_lens is not None
         output = torch.empty_like(original_query)
+        seq_lens = attn_metadata.encoder_seq_lens \
+            if attn_type == AttentionType.ENCODER else attn_metadata.seq_lens
+        assert seq_lens
         start = 0
-        for i, seq_len in enumerate(attn_metadata.seq_lens):
+        for i, seq_len in enumerate(seq_lens):
             end = start + seq_len
             out = xops.memory_efficient_attention_forward(
                 query[None, start:end],
diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py
index 598ceea130d..150c99fc97f 100644
--- a/vllm/attention/ops/ipex_attn.py
+++ b/vllm/attention/ops/ipex_attn.py
@@ -105,6 +105,7 @@ def forward_decode(
             block_size,
             max_context_len,
             alibi_slopes,
+            None,  # TODO add custom bias
             kv_cache_dtype,
             k_scale,
             v_scale,
diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py
index 2c60bd0c38d..89ae554bff9 100644
--- a/vllm/attention/ops/paged_attn.py
+++ b/vllm/attention/ops/paged_attn.py
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
 
@@ -97,6 +96,7 @@ def forward_decode(
         num_kv_heads: int,
         scale: float,
         alibi_slopes: Optional[torch.Tensor],
+        attn_bias: Optional[torch.Tensor],
         k_scale: torch.Tensor,
         v_scale: torch.Tensor,
         tp_rank: int = 0,
@@ -142,6 +142,7 @@ def forward_decode(
                 block_size,
                 max_seq_len,
                 alibi_slopes,
+                attn_bias,
                 kv_cache_dtype,
                 k_scale,
                 v_scale,
@@ -180,6 +181,7 @@ def forward_decode(
                 block_size,
                 max_seq_len,
                 alibi_slopes,
+                attn_bias,
                 kv_cache_dtype,
                 k_scale,
                 v_scale,
@@ -205,11 +207,13 @@ def forward_prefix(
         context_lens: torch.Tensor,
         max_query_len: int,
         alibi_slopes: Optional[torch.Tensor],
+        attn_bias: Optional[torch.Tensor],
         sliding_window: Optional[int],
         k_scale: torch.Tensor,
         v_scale: torch.Tensor,
     ) -> torch.Tensor:
         output = torch.empty_like(query)
+        assert attn_bias is None, (
+            "Custom attention bias is not yet supported in forward_prefix.")
         context_attention_fwd(
             query,
             key,
diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py
index 53f89996f0f..4a25a6c37b3 100644
--- a/vllm/inputs/preprocess.py
+++ b/vllm/inputs/preprocess.py
@@ -91,7 +91,7 @@ def get_decoder_start_token_id(self) -> Optional[int]:
 
         return dec_start_token_id
 
-    def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
+    def _maybe_get_default_enc_dec_decoder_prompt(self) -> List[int]:
         '''
         Specifically for encoder/decoder models:
         generate a default decoder prompt for when
@@ -124,8 +124,7 @@ def _get_default_enc_dec_decoder_prompt(self) -> List[int]:
         '''
 
         bos_token_id = self.get_bos_token_id()
-        assert bos_token_id is not None
-        return [bos_token_id]
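+        # Returning an empty prompt is acceptable here; the decoder start
+        # token handling in _prepare_decoder_input_ids_for_generation
+        # covers that case.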
+        return [] if bos_token_id is None else [bos_token_id]
 
     def _prepare_decoder_input_ids_for_generation(
         self,
@@ -157,7 +156,8 @@ def _prepare_decoder_input_ids_for_generation(
         if decoder_input_ids is None:
             # no decoder prompt input ->
             # use decoder_start_token_id as decoder_input_ids
-            decoder_input_ids = self._get_default_enc_dec_decoder_prompt()
+            decoder_input_ids = (
+                self._maybe_get_default_enc_dec_decoder_prompt())
 
         if (len(decoder_input_ids) == 0
                 or decoder_input_ids[0] != decoder_start_token_id):
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index c2d0fae7056..51c04737bb5 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -104,6 +104,8 @@
     "BartModel": ("bart", "BartForConditionalGeneration"),
     "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
     "Florence2ForConditionalGeneration": ("florence2", "Florence2ForConditionalGeneration"),  # noqa: E501
+    "T5Model": ("t5", "T5ForConditionalGeneration"),
+    "T5ForConditionalGeneration": ("t5", "T5ForConditionalGeneration"),
 }
 
 _EMBEDDING_MODELS = {
diff --git a/vllm/model_executor/models/t5.py b/vllm/model_executor/models/t5.py
new file mode 100644
index 00000000000..70a2f2a726e
--- /dev/null
+++ b/vllm/model_executor/models/t5.py
@@ -0,0 +1,790 @@
+# SPDX-License-Identifier: Apache-2.0
+# Derived from T5 implementation posted on HuggingFace; license below:
+#
+# coding=utf-8
+# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch T5 model."""
+
+import math
+import re
+from typing import Iterable, List, Optional, Set, Tuple
+
+import torch
+from torch import nn
+from transformers import T5Config
+# TODO best way to handle xformers imports?
+from xformers.ops.fmha.attn_bias import LowerTriangularMaskWithTensorBias
+
+# TODO func should be in backend interface
+from vllm.attention.backends.xformers import (XFormersMetadata, _get_attn_bias,
+                                              _set_attn_bias)
+from vllm.attention.layer import Attention, AttentionMetadata, AttentionType
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import IntermediateTensors
+
+from .utils import maybe_prefix
+
+
+class T5LayerNorm(nn.Module):
+
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        Construct a layernorm module in the T5 style. 
+        No bias and no subtraction of mean.
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        # T5 uses a layer_norm which only scales and doesn't shift, which is
+        # also known as Root Mean Square Layer Normalization
+        # https://arxiv.org/abs/1910.07467 thus variance is calculated w/o mean
+        # and there is no bias. Additionally we want to make sure that the
+        # accumulation for half-precision inputs is done in fp32.
+
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1,
+                                                               keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance +
+                                                    self.variance_epsilon)
+
+        # convert into half-precision if necessary
+        if self.weight.dtype in [torch.float16, torch.bfloat16]:
+            hidden_states = hidden_states.to(self.weight.dtype)
+
+        return self.weight * hidden_states
+
+
+class T5DenseActDense(nn.Module):
+
+    def __init__(self,
+                 config: T5Config,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.wi = ColumnParallelLinear(config.d_model, config.d_ff, bias=False)
+        self.wo = RowParallelLinear(config.d_ff,
+                                    config.d_model,
+                                    bias=False,
+                                    quant_config=quant_config)
+        self.act = get_act_fn(config.dense_act_fn)
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        hidden_states, _ = self.wi(hidden_states)
+        hidden_states = self.act(hidden_states)
+        # NOTE: the HF reference casts hidden_states to self.wo's weight
+        # dtype here for non-int8 weights; that cast is not ported.
+        hidden_states, _ = self.wo(hidden_states)
+        return hidden_states
+
+
+class T5DenseGatedActDense(nn.Module):
+
+    def __init__(self,
+                 config: T5Config,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.wi_0 = ColumnParallelLinear(config.d_model,
+                                         config.d_ff,
+                                         bias=False,
+                                         quant_config=quant_config)
+        self.wi_1 = ColumnParallelLinear(config.d_model,
+                                         config.d_ff,
+                                         bias=False,
+                                         quant_config=quant_config)
+        # Should not run in fp16 unless mixed-precision is used,
+        # see https://github.com/huggingface/transformers/issues/20287.
+        self.wo = RowParallelLinear(config.d_ff,
+                                    config.d_model,
+                                    bias=False,
+                                    quant_config=quant_config)
+        self.act = get_act_fn(config.dense_act_fn)
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        hidden_gelu = self.act(self.wi_0(hidden_states)[0])
+        hidden_linear, _ = self.wi_1(hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        hidden_states, _ = self.wo(hidden_states)
+        return hidden_states
+
+
+class T5LayerFF(nn.Module):
+
+    def __init__(self,
+                 config: T5Config,
+                 quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        if config.is_gated_act:
+            self.DenseReluDense = T5DenseGatedActDense(
+                config, quant_config=quant_config)
+        else:
+            self.DenseReluDense = T5DenseActDense(config,
+                                                  quant_config=quant_config)
+
+        self.layer_norm = T5LayerNorm(config.d_model,
+                                      eps=config.layer_norm_epsilon)
+
+    def forward(self, hidden_states) -> torch.Tensor:
+        forwarded_states = self.layer_norm(hidden_states)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        hidden_states = hidden_states + forwarded_states
+        return hidden_states
+
+
+class T5Attention(nn.Module):
+
+    def __init__(self,
+                 config: T5Config,
+                 attn_type: AttentionType,
+                 has_relative_attention_bias=False,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.attn_type = attn_type
+        # Cross-attention has no relative pos encoding anyway
+        self.is_decoder = attn_type == AttentionType.DECODER
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.relative_attention_num_buckets = \
+            config.relative_attention_num_buckets
+        self.relative_attention_max_distance = \
+            config.relative_attention_max_distance
+        self.d_model = config.d_model
+        self.key_value_proj_dim = config.d_kv
+        assert cache_config
+        # Alternatively we can get it from kv_cache size in fwd.
+        self.block_size = cache_config.block_size
+
+        # Partition heads across multiple tensor parallel GPUs.
+        tp_world_size = get_tensor_model_parallel_world_size()
+        assert config.num_heads % tp_world_size == 0
+        self.n_heads = config.num_heads // tp_world_size
+
+        self.inner_dim = self.n_heads * self.key_value_proj_dim
+        # No GQA in t5.
+        self.n_kv_heads = self.n_heads
+
+        # Pass the total head count and the checkpoint head size (d_kv);
+        # QKVParallelLinear shards across tensor-parallel ranks internally.
+        self.qkv_proj = QKVParallelLinear(self.d_model,
+                                          self.key_value_proj_dim,
+                                          config.num_heads,
+                                          config.num_heads,
+                                          bias=False,
+                                          quant_config=quant_config)
+
+        # NOTE (NickLucche) T5 employs a scaled weight initialization scheme
+        # instead of scaling attention scores directly.
+        self.attn = Attention(self.n_heads,
+                              config.d_kv,
+                              1.0,
+                              cache_config=cache_config,
+                              quant_config=quant_config,
+                              prefix=f"{prefix}.attn",
+                              attn_type=self.attn_type)
+
+        # Only the first self-attention layer in the encoder and in the
+        # decoder has this embedding layer; the others reuse its output.
+        if self.has_relative_attention_bias:
+            self.relative_attention_bias = VocabParallelEmbedding(
+                self.relative_attention_num_buckets,
+                self.n_heads,
+                org_num_embeddings=self.relative_attention_num_buckets,
+                quant_config=quant_config)
+        self.out_proj = RowParallelLinear(
+            self.inner_dim,
+            self.d_model,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+    @staticmethod
+    def _relative_position_bucket(relative_position,
+                                  bidirectional=True,
+                                  num_buckets=32,
+                                  max_distance=128):
+        """
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+        Translate relative position to a bucket number for relative attention. 
+        The relative position is defined as memory_position - query_position, 
+        i.e. the distance in tokens from the attending position to the 
+        attended-to position. If bidirectional=False, then positive relative 
+        positions are invalid. We use smaller buckets for small absolute 
+        relative_position and larger buckets for larger absolute 
+        relative_positions. All relative positions >=max_distance map to the
+        same bucket. All relative positions <=-max_distance map to the same 
+        bucket. This should allow for more graceful generalization to longer
+        sequences than the model has been trained on
+
+        Args:
+            relative_position: an int32 Tensor
+            bidirectional: a boolean - whether the attention is bidirectional
+            num_buckets: an integer
+            max_distance: an integer
+
+        Returns:
+            a Tensor with the same shape as relative_position, containing int32
+            values in the range [0, num_buckets)
+        """# noqa: E501
+        relative_buckets = 0
+        if bidirectional:
+            num_buckets //= 2
+            relative_buckets += (relative_position > 0).to(
+                torch.long) * num_buckets
+            relative_position = torch.abs(relative_position)
+        else:
+            relative_position = -torch.min(relative_position,
+                                           torch.zeros_like(relative_position))
+        # now relative_position is in the range [0, inf)
+
+        # half of the buckets are for exact increments in positions
+        max_exact = num_buckets // 2
+        is_small = relative_position < max_exact
+
+        # The other half of the buckets are for logarithmically bigger bins
+        # in positions up to max_distance
+        relative_position_if_large = max_exact + (
+            torch.log(relative_position.float() / max_exact) /
+            math.log(max_distance / max_exact) *
+            (num_buckets - max_exact)).to(torch.long)
+        relative_position_if_large = torch.min(
+            relative_position_if_large,
+            torch.full_like(relative_position_if_large, num_buckets - 1))
+
+        relative_buckets += torch.where(is_small, relative_position,
+                                        relative_position_if_large)
+        return relative_buckets
+
+    def compute_bias(self,
+                     query_length,
+                     key_length,
+                     device=None) -> torch.Tensor:
+        """Compute binned relative position bias"""
+        # TODO possible tp issue?
+        if device is None:
+            device = self.relative_attention_bias.weight.device
+        context_position = torch.arange(query_length,
+                                        dtype=torch.long,
+                                        device=device)[:, None]
+        memory_position = torch.arange(key_length,
+                                       dtype=torch.long,
+                                       device=device)[None, :]
+        # relative_position shape: (query_length, key_length)
+        relative_position = memory_position - context_position
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,  # shape (query_length, key_length)
+            bidirectional=(not self.is_decoder),
+            num_buckets=self.relative_attention_num_buckets,
+            max_distance=self.relative_attention_max_distance,
+        )
+        values = self.relative_attention_bias(
+            relative_position_bucket
+        )  # shape (query_length, key_length, num_heads)
+        x = values.permute([2, 0, 1]).unsqueeze(
+            0)  # shape (1, num_heads, query_length, key_length)
+        return x
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,  # (num_tokens, d_model)
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # TODO auto-selection of xformers backend when t5 is detected
+        assert isinstance(attn_metadata, XFormersMetadata)
+        num_seqs = len(
+            attn_metadata.seq_lens) if attn_metadata.seq_lens else len(
+                attn_metadata.encoder_seq_lens)
+        qkv, _ = self.qkv_proj(hidden_states)
+        # Projection of 'own' hidden state (self-attention). No GQA here.
+        q, k, v = qkv.split(self.inner_dim, dim=-1)
+
+        # NOTE (NickLucche) Attn bias is computed once per encoder or decoder
+        # forward, on the first call to T5Attention.forward. Subsequent
+        # *self-attention* layers will reuse it.
+        attn_bias = _get_attn_bias(attn_metadata, self.attn_type)
+        if self.attn_type == AttentionType.ENCODER_DECODER:
+            # Projection of encoder's hidden states, cross-attention.
+            if encoder_hidden_states is None:
+                # Decode phase, kv already cached
+                assert attn_metadata.num_prefills == 0
+                k = None
+                v = None
+            else:
+                assert attn_metadata.num_prefills > 0
+                # Prefill phase (first decoder forward), caching kv
+                qkv_enc, _ = self.qkv_proj(encoder_hidden_states)
+                _, k, v = qkv_enc.split(self.inner_dim, dim=-1)
+            # No custom attention bias should be set when running cross-attn.
+            assert attn_bias is None
+
+        # Not compatible with chunked prefill here (like all encoder-decoder
+        # models), as it assumes a homogeneous batch of either prefills or
+        # decodes.
+        elif self.has_relative_attention_bias:
+            assert attn_bias is None  # to be recomputed
+            # Self-attention. Compute T5 relative positional encoding.
+            # The bias term is computed on the longest sequence in the batch;
+            # biases for shorter sequences are slices of it.
+            # TODO xformers-specific code.
+            align_to = 8
+            # bias expected shape: (num_seqs, NH, L, L_pad) for prefill,
+            #                      (num_seqs, NH, 1, L_pad) for decodes.
+            if self.attn_type == AttentionType.ENCODER:
+                # Encoder prefill stage, uses xFormers, hence sequence
+                # padding/alignment to 8 is required.
+                seq_len = attn_metadata.max_encoder_seq_len
+                padded_seq_len = (seq_len + align_to -
+                                  1) // align_to * align_to
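+                # e.g. seq_len=13 -> padded_seq_len=16.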
+                # TODO (NickLucche) avoid extra copy on repeat,
+                # provide multiple slices of same memory
+                position_bias = self.compute_bias(seq_len,
+                                                  padded_seq_len).repeat(
+                                                      num_seqs, 1, 1, 1)
+                # xFormers expects a list of biases, one matrix per sequence.
+                # As each sequence gets its own bias, no masking is required.
+                attn_bias = [
+                    p[None, :, :sq, :sq] for p, sq in zip(
+                        position_bias, attn_metadata.encoder_seq_lens)
+                ]
+            elif attn_metadata.prefill_metadata:
+                # Decoder prefill stage, uses xFormers, hence sequence
+                # padding/alignment to 8 is required. First decoder step,
+                # seq_len is usually 1, but one can prepend different start
+                # tokens prior to generation.
+                seq_len = attn_metadata.max_prefill_seq_len
+                # Align sequence length to a multiple of 8.
+                padded_seq_len = (seq_len + align_to -
+                                  1) // align_to * align_to
+                position_bias = self.compute_bias(seq_len,
+                                                  padded_seq_len).repeat(
+                                                      num_seqs, 1, 1, 1)
+                # Causal mask for prefill.
+                attn_bias = [
+                    LowerTriangularMaskWithTensorBias(pb[None, :, :sq, :sq])
+                    for pb, sq in zip(position_bias, attn_metadata.seq_lens)
+                ]
+            else:
+                # Decoder decoding stage, uses PagedAttention, hence sequence
+                # padding/alignment to `block_size` is required. The expected
+                # number of query tokens per sequence is always 1
+                # (multi-query decoding is not supported).
+                seq_len = attn_metadata.max_decode_seq_len
+                block_aligned_seq_len = (seq_len + self.block_size - 1
+                                         ) // self.block_size * self.block_size
+
+                # TODO bf16 bias support in PagedAttention.
+                position_bias = self.compute_bias(
+                    seq_len, block_aligned_seq_len).float()
+                # Bias for the last query, the one at current decoding step.
+                position_bias = position_bias[:, :, -1:, :].repeat(
+                    num_seqs, 1, 1, 1)
+                # No explicit masking required, this is done inside the
+                # paged attention kernel based on the sequence length.
+                attn_bias = [position_bias]
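+                # This single [num_seqs, num_heads, 1, block_aligned] bias is
+                # later unwrapped and handed to the paged attention kernel by
+                # the xFormers backend decode path.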
+
+            # NOTE Assign bias term on metadata based on attn type:
+            # ENCODER->`encoder_attn_bias`, DECODER->`attn_bias`.
+            _set_attn_bias(attn_metadata, attn_bias, self.attn_type)
+        else:
+            # Encoder/Decoder Self-Attention Layer, attn bias already cached.
+            assert attn_bias is not None
+
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
+        output, _ = self.out_proj(attn_output)
+        return output
+
+
+class T5LayerSelfAttention(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        has_relative_attention_bias=False,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.SelfAttention = T5Attention(
+            config,
+            AttentionType.DECODER
+            if "decoder" in prefix else AttentionType.ENCODER,
+            has_relative_attention_bias=has_relative_attention_bias,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.SelfAttention")
+        self.layer_norm = T5LayerNorm(config.d_model,
+                                      eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        normed_hidden_states = self.layer_norm(hidden_states)
+        attention_output = self.SelfAttention(
+            hidden_states=normed_hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+            encoder_hidden_states=None,
+        )
+        hidden_states = hidden_states + attention_output
+        return hidden_states
+
+
+class T5LayerCrossAttention(nn.Module):
+
+    def __init__(self,
+                 config,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.EncDecAttention = T5Attention(config,
+                                           AttentionType.ENCODER_DECODER,
+                                           has_relative_attention_bias=False,
+                                           cache_config=cache_config,
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.EncDecAttention")
+        self.layer_norm = T5LayerNorm(config.d_model,
+                                      eps=config.layer_norm_epsilon)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        normed_hidden_states = self.layer_norm(hidden_states)
+        attention_output = self.EncDecAttention(
+            hidden_states=normed_hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+            encoder_hidden_states=encoder_hidden_states,
+        )
+        hidden_states = hidden_states + attention_output
+        return hidden_states
+
+
+class T5Block(nn.Module):
+
+    def __init__(self,
+                 config: T5Config,
+                 is_decoder: bool,
+                 has_relative_attention_bias=False,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.is_decoder = is_decoder
+        self.self_attn = T5LayerSelfAttention(
+            config,
+            has_relative_attention_bias=has_relative_attention_bias,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn")
+
+        if self.is_decoder:
+            self.cross_attn = T5LayerCrossAttention(
+                config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.cross_attn")
+
+        self.ffn = T5LayerFF(config, quant_config=quant_config)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
+        )
+        if self.is_decoder:
+            hidden_states = self.cross_attn(
+                hidden_states=hidden_states,
+                kv_cache=kv_cache,
+                attn_metadata=attn_metadata,
+                encoder_hidden_states=encoder_hidden_states,
+            )
+
+        # Apply Feed Forward layer
+        hidden_states = self.ffn(hidden_states)
+        return hidden_states
+
+
+class T5Stack(nn.Module):
+
+    def __init__(self,
+                 config: T5Config,
+                 is_decoder: bool,
+                 n_layers: int,
+                 embed_tokens=None,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.embed_tokens = embed_tokens
+        # Only the first block has relative positional encoding.
+        self.blocks = nn.ModuleList([
+            T5Block(config,
+                    is_decoder=is_decoder,
+                    has_relative_attention_bias=i == 0,
+                    cache_config=cache_config,
+                    quant_config=quant_config,
+                    prefix=f"{prefix}.blocks.{i}") for i in range(n_layers)
+        ])
+        self.final_layer_norm = T5LayerNorm(config.d_model,
+                                            eps=config.layer_norm_epsilon)
+
+    def forward(
+            self,
+            input_ids: torch.Tensor,
+            kv_caches: List[torch.Tensor],
+            attn_metadata: AttentionMetadata,
+            encoder_hidden_states: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+
+        for idx, block in enumerate(self.blocks):
+            hidden_states = block(
+                hidden_states=hidden_states,
+                kv_cache=kv_caches[idx],
+                attn_metadata=attn_metadata,
+                encoder_hidden_states=encoder_hidden_states,
+            )
+        hidden_states = self.final_layer_norm(hidden_states)
+        return hidden_states
+
+
+class T5Model(nn.Module):
+    _tied_weights_keys = [
+        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight"
+    ]
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config: T5Config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        lora_config = vllm_config.lora_config
+
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = config.vocab_size + lora_vocab
+        self.padding_idx = config.pad_token_id
+        self.shared = VocabParallelEmbedding(
+            config.vocab_size,
+            config.d_model,
+            org_num_embeddings=config.vocab_size)
+
+        self.encoder = T5Stack(config,
+                               False,
+                               config.num_layers,
+                               self.shared,
+                               cache_config=cache_config,
+                               quant_config=quant_config,
+                               prefix=f"{prefix}.encoder")
+        self.decoder = T5Stack(config,
+                               True,
+                               config.num_decoder_layers,
+                               self.shared,
+                               cache_config=cache_config,
+                               quant_config=quant_config,
+                               prefix=f"{prefix}.decoder")
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.shared(input_ids)
+
+    def forward(self, input_ids: torch.Tensor, encoder_input_ids: torch.Tensor,
+                kv_caches: List[torch.Tensor],
+                attn_metadata: AttentionMetadata) -> torch.Tensor:
+        encoder_hidden_states = None
+
+        if encoder_input_ids.numel() > 0:
+            # Run encoder attention if a non-zero number of encoder tokens
+            # are provided as input: on a regular generate call, the encoder
+            # runs once, on the prompt. Subsequent decoder calls reuse output
+            # `encoder_hidden_states`.
+            encoder_hidden_states = self.encoder(input_ids=encoder_input_ids,
+                                                 kv_caches=kv_caches,
+                                                 attn_metadata=attn_metadata)
+            # Clear attention bias state.
+            attn_metadata.attn_bias = None
+            attn_metadata.encoder_attn_bias = None
+            attn_metadata.cross_attn_bias = None
+
+        decoder_outputs = self.decoder(
+            input_ids=input_ids,
+            encoder_hidden_states=encoder_hidden_states,
+            kv_caches=kv_caches,
+            attn_metadata=attn_metadata)
+
+        # Clear cached attention biases so stale state is not reused,
+        # e.g. when capturing CUDA graphs.
+        attn_metadata.attn_bias = None
+        attn_metadata.encoder_attn_bias = None
+        attn_metadata.cross_attn_bias = None
+        return decoder_outputs
+
+
+class T5ForConditionalGeneration(nn.Module):
+    _keys_to_ignore_on_load_unexpected = [
+        "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
+    ]
+    _tied_weights_keys = [
+        "encoder.embed_tokens.weight", "decoder.embed_tokens.weight",
+        "lm_head.weight"
+    ]
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config: T5Config = vllm_config.model_config.hf_config
+        self.model_dim = config.d_model
+        self.config = config
+        self.unpadded_vocab_size = config.vocab_size
+        if lora_config := vllm_config.lora_config:
+            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+
+        self.model = T5Model(vllm_config=vllm_config,
+                             prefix=maybe_prefix(prefix, "model"))
+        # `tie_word_embeddings` defaults to True for HF models when it is
+        # not set in the config.
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.shared
+            # In transformers this is somewhat more explicit, i.e. (after
+            # load) self.lm_head.weight = self.model.shared.weight.
+        else:
+            self.lm_head = ParallelLMHead(self.unpadded_vocab_size,
+                                          config.d_model,
+                                          org_num_embeddings=config.vocab_size)
+
+        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
+                                                config.vocab_size)
+        self.sampler = get_sampler()
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        if self.config.tie_word_embeddings:
+            # Rescale output before projecting on vocab
+            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 # noqa: E501
+            hidden_states = hidden_states * (self.model_dim**-0.5)
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.shared(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        *,
+        encoder_input_ids: torch.Tensor,
+        encoder_positions: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        return self.model(input_ids, encoder_input_ids, kv_caches,
+                          attn_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        model_params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
+        renamed_reg = [
+            (re.compile(r'block\.(\d+)\.layer\.0'), r'blocks.\1.self_attn'),
+            (re.compile(r'decoder.block\.(\d+)\.layer\.1'),
+             r'decoder.blocks.\1.cross_attn'),
+            (re.compile(r'decoder.block\.(\d+)\.layer\.2'),
+             r'decoder.blocks.\1.ffn'),
+            # encoder has no cross-attn, but rather self-attention+ffn.
+            (re.compile(r'encoder.block\.(\d+)\.layer\.1'),
+             r'encoder.blocks.\1.ffn'),
+            (re.compile(r'\.o\.'), r'.out_proj.'),
+        ]
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj.", ".q.", "q"),
+            (".qkv_proj.", ".k.", "k"),
+            (".qkv_proj.", ".v.", "v")
+        ]
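+        # Example: "encoder.block.0.layer.0.SelfAttention.q.weight" maps to
+        # "model.encoder.blocks.0.self_attn.SelfAttention.qkv_proj.weight"
+        # after the regex renames and the stacked-param mapping below.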
+
+        for name, loaded_weight in weights:
+            # No relative position attn bias on cross attention.
+            if name in self._keys_to_ignore_on_load_unexpected:
+                continue
+
+            # Handle some renaming
+            for reg in renamed_reg:
+                name = re.sub(*reg, name)
+
+            top_module, _ = name.split('.', 1)
+            if top_module != 'lm_head':
+                name = f"model.{name}"
+
+            # Split q/k/v layers to unified QKVParallelLinear
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = model_params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Not a q/k/v layer.
+                param = model_params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params