Skip to content

Hybrid recurrent cache #13979

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: master
Choose a base branch
from

Conversation

gabe-l-hart
Copy link
Contributor

This is a re-opened version of #13904 after #13746 was merged

Description

This PR introduces the llama_kv_cache_hybrid_recurrent cache implementation. It follows the pattern of llama_kv_cache_unified_iswa by holding two child cache instances and implementing the interface logic such that it manages both correctly for the appropriate layers.

Changes

The main change in this PR is the addition of llama_kv_cache_hybrid_recurrent in llama-kv-cache-hybrid-recurrent.*. In addition to this, the PR does the following:

  • Add the llama_model_is_hybrid_recurrent public API (akin to llama_model_is_recurrent)
  • Add LLM_KV_ATTENTION_LAYER_INDICES as an hparam to hold the indices of the layers that should use attention (versus recurrent)
    • This part is not well aligned with iswa, but that mechanism also isn't particularly extensible. It might be more appropriate to have a generic mechanism for indicating the type of caching to use for each layer, but that would start to approach the generic hybrid implementation that I originally attempted which ended up being too abstract (feat: Hybrid unified/recurrent cache #13276).
  • Abstracting utilities in llm_graph_context that need a specific type of cache to use getters (get_state_unified / get_state_recurrent) that will properly handle llama_kv_cache_hybrid_recurrent
  • Make n_embd_k_s / n_embd_v_s layer-dependent and use layer indices when calling them in the existing cache implementations
  • Add layer filtering to llama_kv_cache_recurrent
  • Updates the logic in llama_model::create_memory to use llm_arch_is_recurrent and llm_arch_is_hybrid_recurrent rather than relying on adding models to the switch statement which was redundant with the implementation of these functions

@gabe-l-hart gabe-l-hart mentioned this pull request Jun 2, 2025
1 task
Comment on lines 1063 to 1086
const llama_kv_cache_unified_state * llm_graph_context::get_state_unified() const {
const auto * umstate = dynamic_cast<const llama_kv_cache_unified_state *>(mstate);
if (!umstate) {
const auto hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
if (hmstate) {
umstate = hmstate->get_state_attn();
}
}
GGML_ASSERT(umstate);
return umstate;
}

const llama_kv_cache_recurrent_state * llm_graph_context::get_state_recurrent() const {
const auto * rmstate = dynamic_cast<const llama_kv_cache_recurrent_state *>(mstate);
if (!rmstate) {
const auto hmstate = dynamic_cast<const llama_kv_cache_hybrid_recurrent_state *>(mstate);
if (hmstate) {
rmstate = hmstate->get_state_recurrent();
}
}
GGML_ASSERT(rmstate);
return rmstate;
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These dynamic casts should not be necessary. Instead you need a new llm_graph_context::build_attn_inp_kv_hybrid_recurrent() method, similar to llm_graph_context::build_attn_inp_kv_unified_iswa().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm working through this now and a couple of questions are coming up:

  1. Would it be best to combine build_inp_s_copy with build_attn_inp_kv for hybrid so that models call just one "build inputs" function, or keep them separate for simplicity?
  2. For the build_attn methods, each has a corresponding llm_graph_input_attn_* class. The build_inp_s_* methods don't have this pattern which would make this a bit harder to have code reuse. Are there plans to refactor that further @compilade?
  3. In the mamba2 branch, s_mask seems to be totally removed. I'd prefer not to do all of the boilerplate for duplicating build_inp_s_mask for the hybrid recurrent case if that's definitely going to be going away. Is there any reason that might stick around past the merge of mamba2?

Copy link
Collaborator

@compilade compilade Jun 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Answering out of order, but it should still make sense:

2. For the build_attn methods, each has a corresponding llm_graph_input_attn_* class. The build_inp_s_* methods don't have this pattern

They do follow this pattern, see

class llm_graph_input_s_copy : public llm_graph_input_i {

(this is on the current master)

I think you might mean the build_attn_* methods also return instances of llm_graph_input_attn_*?
That seems to be directly related to llm_graph_context::build_attn() which has multiple implementations which differ by the type of the first argument (e.g. for llm_graph_input_attn_kv_unified, llm_graph_input_attn_no_cache, etc.)

Are there plans to refactor that further @compilade?

Not really, outside of removing s_mask (and related functions and classes) as part of #13834.

  1. Would it be best to combine build_inp_s_copy with build_attn_inp_kv for hybrid so that models call just one "build inputs" function, or keep them separate for simplicity?

Personally, I think it would be simpler to keep them separate, because they are fundamentally different (one is intended to be used by build_copy_mask_state (renamed to build_recurrent_state in #13834), while the other is used by build_attn), and they are pretty much independent, even in hybrid models (at least for Jamba, the recurrent and self-attention layers are mostly independent on that front).

I don't see how build_attn would ever need s_copy.

build_inp_s_copy and build_inp_attn_kv_* are called once at the beginning of the graph, while build_attn and build_recurrent_state are called once per layer (where applicable, and so usually different layers for both).

3. Is there any reason [s_mask] might stick around past the merge of mamba2?

No reason to keep it, s_mask will be removed. Its functionality is redundant with s_copy, and otherwise prevents minimizing unnecessary state copies. It was used to clear the states, but the same can be done through inp_s_copy and clearing by copying a zero-ed state (which is the rs_z'th state in the mamba2 branch (and #13834)).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, that's super helpful! I was missing the distinction between build_attn_inp and build_attn which makes perfect sense.

Personally, I think it would be simpler to keep them separate

I agree on my personal gut feeling, so I'll go with this and see how it feels once complete.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think this feels a lot cleaner now. For build_inp_s_copy, I opted to add an optional parameter so that the caller can take ownership of casting the cache state rather than duplicating the function into build_inp_s_copy_hybrid. That felt a little cleaner w.r.t. code reuse, but I'm happy to do a separate method if that's preferred.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like there's one more place that will need changing in build_copy_mask_state (renamed to build_recurrent_state on mamba2). Similar to build_inp_s_copy, I think the cleanest way to do this for code reuse is to add an optional parameter that, if unset, will use the current logic of casting mstate.

Comment on lines +118 to +75
// TODO: will the recurrent cache be in an undefined state at this point?
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but that will be fixed in #13834

(Noting here in case this gets merged first so that I don't forget to update the comment)

Comment on lines 153 to 154
// TODO: Should this return true if the attention cache can shift?
return false;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be true, and for recurrent as well.

I don't know why this was ever set to false for the recurrent cache, since it always supported context shifts (or at least position shifts).

@gabe-l-hart gabe-l-hart force-pushed the HybridRecurrentCache branch 4 times, most recently from 1990f3b to 85d2917 Compare June 5, 2025 20:07
Also, split llama_model_is_recurrent into llm_arch_is_recurrent in
llama-arch with llama_model_is_recurrent delegating to
llm_arch_is_recurrent. The same split is done for hybird. This is needed
because there are places where the llama_model has not yet been initialized
but we need to check if the model is recurrent (specifically for the
per-layer recurrent check array in hparams).

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
…s in hparams

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
…l is recurrent

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
The implementation of the hybrid cache intentionally does not specify the
types of the child caches, so there was a naming mismatch with these
predicate functions that used "hybrid" to imply "hybrid recurrent."

Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
This follows the pattern in iswa where the two child caches are held
explicitly to support the case where a model requires a single attention
cache and a single recurrent cache where each layer uses exactly one of the
caches.

This is a rewrite of the more generic approach in the original hybrid cache
PR: ggml-org#13276

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
This includes a refactor of the create_memory logic to avoid needing to use
the arch enum explicitly unless a model needs explicit cache instantiation
logic beyond the standard logic for recurrent, hybrid, unified, and iswa.

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
NOTE: I intentionally did not add support for s_mask since it will be going
away soon

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
…lits in unified cache

Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
…he interface

Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
Branch: GraniteFour

Signed-off-by: Gabe Goodhart <[email protected]>
@gabe-l-hart gabe-l-hart force-pushed the HybridRecurrentCache branch from 8a7e8ef to 6cf35ff Compare June 6, 2025 15:38
Branch: HybridRecurrentCache

Signed-off-by: Gabe Goodhart <[email protected]>
@gabe-l-hart gabe-l-hart force-pushed the HybridRecurrentCache branch from 6cf35ff to ab918bb Compare June 6, 2025 15:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants