@@ -1740,8 +1740,8 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
        uint32_t n_seq_max,
        uint32_t n_batch,
        uint32_t n_pad) : hparams(model.hparams) {
-    llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
-    llama_kv_cache_unified::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
+    llama_kv_cache::layer_filter_cb filter_swa  = [&](int32_t il) { return  model.hparams.is_swa(il); };

    const uint32_t size_base = kv_size;

@@ -2975,3 +2975,227 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce

    return true;
}
+
+//
+// llama_kv_cache_hybrid_recurrent
+//
+
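+// decode state for the hybrid cache: holds the split batch and advances both
+// child caches in lockstep as each ubatch is consumed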
+class llama_kv_cache_hybrid_recurrent_decode_state_t : public llama_memory_decode_state_i {
+public:
+    llama_kv_cache_hybrid_recurrent_decode_state_t(llama_memory_status status) : status(status) {}
+
+    llama_kv_cache_hybrid_recurrent_decode_state_t(
+            llama_kv_cache_hybrid_recurrent * kv,
+            llama_sbatch sbatch,
+            std::vector<uint32_t> heads_attn,
+            std::vector<llama_ubatch> ubatches)
+        : status(LLAMA_MEMORY_STATUS_SUCCESS),
+          kv(kv),
+          sbatch(std::move(sbatch)),
+          heads_attn(std::move(heads_attn)),
+          ubatches(std::move(ubatches)) {
+    }
+
+    ~llama_kv_cache_hybrid_recurrent_decode_state_t() = default;
+
+    llama_ubatch * next() override {
+        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+        if (i_next >= ubatches.size()) {
+            return nullptr;
+        }
+
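+        // reserve space for this ubatch in both child caches before handing it out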
+        kv->get_kv_attn     ()->fill_slot(heads_attn[i_next], ubatches[i_next]);
+        kv->get_kv_recurrent()->find_slot(ubatches[i_next]);
+
+        return &ubatches[i_next++];
+    }
+
+    std::vector<int64_t> & out_ids() override {
+        assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
+
+        return sbatch.out_ids;
+    }
+
+    llama_memory_status get_status() const override {
+        return status;
+    }
+
+private:
+    const llama_memory_status status;
+
+    llama_kv_cache_hybrid_recurrent * kv;
+
+    llama_sbatch sbatch;
+
+    // the index of the next ubatch to process
+    size_t i_next = 0;
+
+    std::vector<uint32_t> heads_attn;
+    std::vector<llama_ubatch> ubatches;
+};
+
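+// layers are routed by hparams.recurrent_layer(il): attention layers go to the
+// unified cache, recurrent layers go to the recurrent cache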
+llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
+        const llama_model & model,
+        /* attn */
+        ggml_type attn_type_k,
+        ggml_type attn_type_v,
+        bool attn_v_trans,
+        uint32_t attn_kv_size,
+        uint32_t attn_n_pad,
+        uint32_t attn_n_swa,
+        llama_swa_type attn_swa_type,
+        /* recurrent */
+        ggml_type recurrent_type_k,
+        ggml_type recurrent_type_v,
+        uint32_t recurrent_kv_size,
+        /* common */
+        uint32_t n_seq_max,
+        bool offload) :
+    hparams(model.hparams),
+    kv_attn(new llama_kv_cache_unified(
+        model,
+        [&](int32_t il) { return !model.hparams.recurrent_layer(il); },
+        attn_type_k,
+        attn_type_v,
+        attn_v_trans,
+        offload,
+        attn_kv_size,
+        n_seq_max,
+        attn_n_pad,
+        attn_n_swa,
+        attn_swa_type
+    )),
+    kv_recurrent(new llama_kv_cache_recurrent(
+        model,
+        [&](int32_t il) { return model.hparams.recurrent_layer(il); },
+        recurrent_type_k,
+        recurrent_type_v,
+        offload,
+        recurrent_kv_size,
+        n_seq_max
+    )) {}
+
+void llama_kv_cache_hybrid_recurrent::clear() {
+    kv_attn     ->clear();
+    kv_recurrent->clear();
+}
+
+bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    // Try removing from the recurrent cache first, since it may fail.
+    // If it does fail, the cache will not have been mutated.
+    if (!kv_recurrent->seq_rm(seq_id, p0, p1)) {
+        return false;
+    }
+    return kv_attn->seq_rm(seq_id, p0, p1);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    kv_attn     ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+    kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) {
+    kv_attn     ->seq_keep(seq_id);
+    kv_recurrent->seq_keep(seq_id);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
+    kv_attn     ->seq_add(seq_id, p0, p1, shift);
+    kv_recurrent->seq_add(seq_id, p0, p1, shift);
+}
+
+void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
+    kv_attn     ->seq_div(seq_id, p0, p1, d);
+    kv_recurrent->seq_div(seq_id, p0, p1, d);
+}
+
+llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const {
+    // the min of the total cache is the max of the two caches' min values
+    return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id));
+}
+
+llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const {
+    // the max of the total cache is the min of the two caches' max values
+    return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
+}
+
+llama_memory_decode_state_ptr llama_kv_cache_hybrid_recurrent::init(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {
+    // since this includes a recurrent cache, we cannot use split_simple
+    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+
+    // follow the recurrent pattern for creating the ubatch splits
+    std::vector<llama_ubatch> ubatches;
+    while (sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = sbatch.split_seq(n_ubatch);
+        } else {
+            ubatch = sbatch.split_equal(n_ubatch);
+        }
+
+        ubatches.push_back(ubatch);
+    }
+
+    // prepare the recurrent batches first
+    if (!kv_recurrent->prepare(ubatches)) {
+        // TODO: will the recurrent cache be in an undefined state at this point?
+        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    // prepare the attention cache
+    auto heads_attn = kv_attn->prepare(ubatches);
+    if (heads_attn.empty()) {
+        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+        return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
+    }
+
+    return std::make_unique<llama_kv_cache_hybrid_recurrent_decode_state_t>(
+        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
+}
+
+bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) {
+    bool res = false;
+
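+    // bitwise OR (not ||) so that neither child update is short-circuited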
+    res = res | kv_attn     ->update(lctx);
+    res = res | kv_recurrent->update(lctx);
+
+    return res;
+}
+
+void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) {
+    kv_attn     ->defrag_sched(thold);
+    kv_recurrent->defrag_sched(thold);
+}
+
+void llama_kv_cache_hybrid_recurrent::set_full() {
+    kv_attn     ->set_full();
+    kv_recurrent->set_full();
+}
+
+bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
+    // TODO: Should this return true if the attention cache can shift?
+    return false;
+}
+
+void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
+    kv_attn     ->state_write(io, seq_id);
+    kv_recurrent->state_write(io, seq_id);
+}
+
+void llama_kv_cache_hybrid_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
+    kv_attn     ->state_read(io, seq_id);
+    kv_recurrent->state_read(io, seq_id);
+}
+
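+// expose the child caches; the decode state above uses these to place each ubatch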
+llama_kv_cache_unified * llama_kv_cache_hybrid_recurrent::get_kv_attn() const {
+    return kv_attn.get();
+}
+
+llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() const {
+    return kv_recurrent.get();
+}
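
A minimal usage sketch (not part of the diff), based only on the constructor and decode-state interface shown above; the concrete types, sizes, and flags are illustrative assumptions:

// hypothetical wiring for a model that mixes attention and recurrent layers
llama_kv_cache_hybrid_recurrent kv(
    model,
    /* attn      */ GGML_TYPE_F16, GGML_TYPE_F16, /* v_trans */ true,
                    /* kv_size */ 4096, /* n_pad */ 32, /* n_swa */ 0, LLAMA_SWA_TYPE_NONE,
    /* recurrent */ GGML_TYPE_F32, GGML_TYPE_F32, /* kv_size */ 4,
    /* common    */ /* n_seq_max */ 4, /* offload */ true);

// split a batch and walk the ubatches; next() advances both child caches in lockstep
auto dstate = kv.init(batch, /* n_ubatch */ 512, /* embd_pooled */ false, /* logits_all */ false);
if (dstate->get_status() == LLAMA_MEMORY_STATUS_SUCCESS) {
    while (llama_ubatch * ubatch = dstate->next()) {
        // ... build and evaluate the graph for *ubatch ...
    }
}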