
Commit 32177c3

feat: First pass at llama_kv_cache_hybrid_recurrent

This follows the pattern in the iSWA cache (llama-kv-cache-unified-iswa), where the two child caches are held explicitly, to support the case where a model requires a single attention cache and a single recurrent cache and each layer uses exactly one of them. This is a rewrite of the more generic approach in the original hybrid cache PR: ggml-org#13276

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>

1 parent 13eeda7 commit 32177c3
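As a minimal illustration of the per-layer routing described in the commit message (not code from this commit): each child cache is built with a layer filter, and every layer index satisfies exactly one of the two filters. The is_recurrent_layer predicate and the layer count below are hypothetical stand-ins for llama_hparams::recurrent_layer and a model's real layer count.

#include <cstdint>
#include <cstdio>
#include <functional>
#include <vector>

// Hypothetical stand-in for llama_hparams::recurrent_layer(il):
// pretend every third layer of a 9-layer model is recurrent.
static bool is_recurrent_layer(int32_t il) {
    return il % 3 == 2;
}

// Each child cache receives a filter and only allocates the layers it owns,
// mirroring how kv_attn and kv_recurrent are constructed in this commit.
static std::vector<int32_t> collect_layers(int32_t n_layer, const std::function<bool(int32_t)> & filter) {
    std::vector<int32_t> layers;
    for (int32_t il = 0; il < n_layer; ++il) {
        if (filter(il)) {
            layers.push_back(il);
        }
    }
    return layers;
}

int main() {
    const int32_t n_layer = 9;

    const auto attn_layers      = collect_layers(n_layer, [](int32_t il) { return !is_recurrent_layer(il); });
    const auto recurrent_layers = collect_layers(n_layer, [](int32_t il) { return  is_recurrent_layer(il); });

    printf("attention layers:");
    for (auto il : attn_layers)      printf(" %d", (int) il);
    printf("\nrecurrent layers:");
    for (auto il : recurrent_layers) printf(" %d", (int) il);
    printf("\n");

    // every layer index appears in exactly one of the two lists
    return 0;
}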

File tree

3 files changed: +382, -0 lines changed


src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ add_library(llama
             llama-kv-cache-unified.cpp
             llama-kv-cache-unified-iswa.cpp
             llama-kv-cache-recurrent.cpp
+            llama-kv-cache-hybrid-recurrent.cpp
             llama-memory.cpp
             llama-mmap.cpp
             llama-model-loader.cpp
src/llama-kv-cache-hybrid-recurrent.cpp

Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@
#include "llama-kv-cache-hybrid-recurrent.h"

#include "llama-impl.h"
#include "llama-model.h"
#include "llama-context.h"

//
// llama_kv_cache_hybrid_recurrent
//

llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent(
        const llama_model & model,
        /* attn */
        ggml_type       attn_type_k,
        ggml_type       attn_type_v,
        bool            attn_v_trans,
        uint32_t        attn_kv_size,
        uint32_t        attn_n_pad,
        uint32_t        attn_n_swa,
        llama_swa_type  attn_swa_type,
        /* recurrent */
        ggml_type       recurrent_type_k,
        ggml_type       recurrent_type_v,
        uint32_t        recurrent_kv_size,
        /* common */
        uint32_t        n_seq_max,
        bool            offload) :
    hparams(model.hparams),
    kv_attn(new llama_kv_cache_unified(
        model,
        [&](int32_t il) { return !model.hparams.recurrent_layer(il); },
        attn_type_k,
        attn_type_v,
        attn_v_trans,
        offload,
        attn_kv_size,
        n_seq_max,
        attn_n_pad,
        attn_n_swa,
        attn_swa_type
    )),
    kv_recurrent(new llama_kv_cache_recurrent(
        model,
        [&](int32_t il) { return model.hparams.recurrent_layer(il); },
        recurrent_type_k,
        recurrent_type_v,
        offload,
        recurrent_kv_size,
        n_seq_max
    )) {}

void llama_kv_cache_hybrid_recurrent::clear() {
    kv_attn     ->clear();
    kv_recurrent->clear();
}

bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    // Try removing from the recurrent cache first since it may fail. If it does
    // fail, the cache will not have been mutated.
    if (!kv_recurrent->seq_rm(seq_id, p0, p1)) {
        return false;
    }
    return kv_attn->seq_rm(seq_id, p0, p1);
}

void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    kv_attn     ->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) {
    kv_attn     ->seq_keep(seq_id);
    kv_recurrent->seq_keep(seq_id);
}

void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    kv_attn     ->seq_add(seq_id, p0, p1, shift);
    kv_recurrent->seq_add(seq_id, p0, p1, shift);
}

void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    kv_attn     ->seq_div(seq_id, p0, p1, d);
    kv_recurrent->seq_div(seq_id, p0, p1, d);
}

llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const {
    // the min of the total cache is the max of the two caches' min values
    return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id));
}

llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const {
    // the max of the total cache is the min of the two caches' max values
    return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id));
}

llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) {

    // since this includes a recurrent cache, we cannot use split_simple
    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);

    // follow the recurrent pattern for creating the ubatch splits
    std::vector<llama_ubatch> ubatches;
    while (sbatch.n_tokens > 0) {
        llama_ubatch ubatch;

        if (embd_pooled) {
            // Pooled embeddings cannot be split across ubatches (yet)
            ubatch = sbatch.split_seq(n_ubatch);
        } else {
            ubatch = sbatch.split_equal(n_ubatch);
        }

        ubatches.push_back(ubatch);
    }

    // prepare the recurrent batches first
    if (!kv_recurrent->prepare(ubatches)) {
        // TODO: will the recurrent cache be in an undefined state at this point?
        LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    // prepare the attention cache
    auto heads_attn = kv_attn->prepare(ubatches);
    if (heads_attn.empty()) {
        LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
        return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
    }

    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(
        this, std::move(sbatch), std::move(heads_attn), std::move(ubatches));
}

llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() {
    return std::make_unique<llama_kv_cache_hybrid_recurrent_state>(this);
}

bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) {
    bool res = false;

    res = res | kv_attn     ->update(lctx);
    res = res | kv_recurrent->update(lctx);

    return res;
}

void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) {
    kv_attn     ->defrag_sched(thold);
    kv_recurrent->defrag_sched(thold);
}

bool llama_kv_cache_hybrid_recurrent::get_can_shift() const {
    // TODO: Should this return true if the attention cache can shift?
    return false;
}

void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    kv_attn     ->state_write(io, seq_id);
    kv_recurrent->state_write(io, seq_id);
}

void llama_kv_cache_hybrid_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    kv_attn     ->state_read(io, seq_id);
    kv_recurrent->state_read(io, seq_id);
}

llama_kv_cache_unified * llama_kv_cache_hybrid_recurrent::get_kv_attn() const {
    return kv_attn.get();
}

llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() const {
    return kv_recurrent.get();
}

llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status)
    : status(status), state_attn(status), state_recurrent(status) {}

llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv)
    : status(LLAMA_MEMORY_STATUS_SUCCESS),
      kv(kv),
      state_attn(status, kv->get_kv_attn()),
      state_recurrent(status, kv->get_kv_recurrent()) {}

llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(
        llama_kv_cache_hybrid_recurrent * kv,
        llama_sbatch sbatch,
        std::vector<uint32_t> heads_attn,
        std::vector<llama_ubatch> ubatches)
    : status(LLAMA_MEMORY_STATUS_SUCCESS),
      kv(kv),
      sbatch(std::move(sbatch)),
      heads_attn(std::move(heads_attn)),
      ubatches(std::move(ubatches)),
      // NOTE: these child states are only used as wrapper APIs for the
      //  const methods, so we use the "init full" signature since the
      //  actual state is not used.
      state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn()),
      state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent()) {}

bool llama_kv_cache_hybrid_recurrent_state::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    if (++i_next >= ubatches.size()) {
        return false;
    }

    return true;
}

bool llama_kv_cache_hybrid_recurrent_state::apply() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    kv->get_kv_attn()     ->apply_ubatch(heads_attn[i_next], ubatches[i_next]);
    kv->get_kv_recurrent()->find_slot(ubatches[i_next]);

    return true;
}

std::vector<int64_t> & llama_kv_cache_hybrid_recurrent_state::out_ids() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    return sbatch.out_ids;
}

llama_memory_status llama_kv_cache_hybrid_recurrent_state::get_status() const {
    return status;
}

const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
    return ubatches[i_next];
}

const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const {
    return &state_attn;
}

const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const {
    return &state_recurrent;
}
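
To make the init_batch / next / apply contract above easier to follow, here is a hedged sketch of how a caller might consume the returned state. The actual driver loop lives in llama_context and is not part of this diff, so the function below (process_batch) and its error handling are illustrative assumptions, not the library's API.

#include "llama-kv-cache-hybrid-recurrent.h"

// Hedged sketch of the consumption pattern for the state returned by init_batch;
// this is NOT code from the commit, and the real driver (llama_context) may differ.
// Assumes `cache` is a constructed llama_kv_cache_hybrid_recurrent and `batch`
// is a prepared llama_batch.
static bool process_batch(llama_kv_cache_hybrid_recurrent & cache, const llama_batch & batch, uint32_t n_ubatch) {
    auto state = cache.init_batch(batch, n_ubatch, /*embd_pooled=*/false, /*logits_all=*/false);
    if (state->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return false;
    }

    do {
        // position the current ubatch in both child caches
        // (attention head placement + recurrent slot lookup)
        if (!state->apply()) {
            return false;
        }

        const llama_ubatch & ubatch = state->get_ubatch();
        // ... build and evaluate the graph for `ubatch` here; per-cache views are
        //     exposed via the derived state's get_state_attn()/get_state_recurrent() ...
        (void) ubatch;
    } while (state->next());

    return true;
}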

src/llama-kv-cache-hybrid-recurrent.h

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache.h"
#include "llama-kv-cache-recurrent.h"
#include "llama-kv-cache-unified.h"

#include <memory>
#include <vector>

//
// llama_kv_cache_hybrid_recurrent
//

// utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to
// support models where each layer may be either attention-based or recurrent

class llama_kv_cache_hybrid_recurrent : public llama_kv_cache {
public:
    llama_kv_cache_hybrid_recurrent(
            const llama_model & model,
            /* attn */
            ggml_type       attn_type_k,
            ggml_type       attn_type_v,
            bool            attn_v_trans,
            uint32_t        attn_kv_size,
            uint32_t        attn_n_pad,
            uint32_t        attn_n_swa,
            llama_swa_type  attn_swa_type,
            /* recurrent */
            ggml_type       recurrent_type_k,
            ggml_type       recurrent_type_v,
            uint32_t        recurrent_kv_size,
            /* common */
            uint32_t        n_seq_max,
            bool            offload);

    ~llama_kv_cache_hybrid_recurrent() = default;

    //
    // llama_memory_i
    //

    void clear() override;

    bool seq_rm  (llama_seq_id seq_id,                              llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    //
    // llama_kv_cache
    //

    llama_memory_state_ptr init_batch(
            const llama_batch & batch,
            uint32_t n_ubatch,
            bool embd_pooled,
            bool logits_all) override;

    llama_memory_state_ptr init_full() override;

    bool update(llama_context & lctx) override;

    void defrag_sched(float thold) override;

    bool get_can_shift() const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

    //
    // llama_kv_cache_hybrid_recurrent specific API
    //

    llama_kv_cache_unified   * get_kv_attn     () const;
    llama_kv_cache_recurrent * get_kv_recurrent() const;

private:
    const llama_hparams & hparams;

    const std::unique_ptr<llama_kv_cache_unified>   kv_attn;
    const std::unique_ptr<llama_kv_cache_recurrent> kv_recurrent;
};

class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i {
public:
    // init failure
    explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status);

    // init full
    explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv);

    // init success
    llama_kv_cache_hybrid_recurrent_state(
        llama_kv_cache_hybrid_recurrent * kv,
        llama_sbatch sbatch,
        std::vector<uint32_t> heads_attn,
        std::vector<llama_ubatch> ubatches);

    ~llama_kv_cache_hybrid_recurrent_state() = default;

    bool next()  override;
    bool apply() override;

    std::vector<int64_t> & out_ids() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_kv_cache_hybrid_recurrent_state_i
    //

    const llama_kv_cache_unified_state   * get_state_attn     () const;
    const llama_kv_cache_recurrent_state * get_state_recurrent() const;

private:
    const llama_memory_status status;

    llama_kv_cache_hybrid_recurrent * kv;

    llama_sbatch sbatch;

    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<uint32_t>     heads_attn;
    std::vector<llama_ubatch> ubatches;

    const llama_kv_cache_unified_state   state_attn;
    const llama_kv_cache_recurrent_state state_recurrent;
};
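
As a usage note for the constructor declared above, a hybrid model would construct this cache where it would otherwise create a plain unified or recurrent cache. The sketch below is a hypothetical construction site; the helper name create_hybrid_cache and every concrete value (cache types, padding, SWA settings) are illustrative placeholders rather than values taken from this commit or any model in the tree.

#include "llama-kv-cache-hybrid-recurrent.h"

// Hedged sketch: hypothetical construction of the hybrid cache.
// Field names follow common llama.cpp conventions, but the concrete
// values are illustrative placeholders only.
static llama_memory_i * create_hybrid_cache(const llama_model & model, uint32_t n_ctx, uint32_t n_seq_max, bool offload) {
    return new llama_kv_cache_hybrid_recurrent(
        model,
        /* attn */
        GGML_TYPE_F16,        // attn_type_k
        GGML_TYPE_F16,        // attn_type_v
        false,                // attn_v_trans
        n_ctx,                // attn_kv_size
        32,                   // attn_n_pad
        0,                    // attn_n_swa (0: no sliding window)
        LLAMA_SWA_TYPE_NONE,  // attn_swa_type
        /* recurrent */
        GGML_TYPE_F32,        // recurrent_type_k
        GGML_TYPE_F32,        // recurrent_type_v
        n_seq_max,            // recurrent_kv_size (assumed: one slot per sequence)
        /* common */
        n_seq_max,
        offload);
}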
