Conversation

@i3hz
Contributor

@i3hz i3hz commented Nov 28, 2025

What does this PR do?

Fixes #42454

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@zucchini-nlp @Rocketknight1 @mobicham

Benchmarking script -

import torch
import time
import numpy as np
from transformers import WhisperForConditionalGeneration
from transformers.cache_utils import StaticCache, EncoderDecoderCache


MODEL_ID = "openai/whisper-tiny"
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MAX_BATCH_SIZE = 64
TEST_BATCHES = [1, 8, 32, 64]

SEQ_LEN = 128
WARMUP = 10
REPEATS = 50


def load_model():
    model = WhisperForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        attn_implementation="sdpa",
    ).to(DEVICE)
    model.eval()
    return model


def run_benchmark(model, batch_size, cache_cap, tag):
    decoder = model.model.decoder
    cache_len = SEQ_LEN + 10
    enc_len = SEQ_LEN

    self_cache = StaticCache(
        config=decoder.config,
        max_batch_size=cache_cap,
        max_cache_len=cache_len,
        device=DEVICE,
        dtype=DTYPE,
    )
    cross_cache = StaticCache(
        config=decoder.config,
        max_batch_size=cache_cap,
        max_cache_len=enc_len,
        device=DEVICE,
        dtype=DTYPE,
    )
    kv_cache = EncoderDecoderCache(self_cache, cross_cache)

    input_ids = torch.randint(0, 1000, (batch_size, SEQ_LEN), device=DEVICE)
    encoder_states = torch.randn(batch_size, enc_len, model.config.d_model, device=DEVICE, dtype=DTYPE)
    cache_pos = torch.arange(SEQ_LEN, device=DEVICE)

    for _ in range(WARMUP):
        kv_cache.reset()
        with torch.no_grad():
            decoder(
                input_ids=input_ids,
                encoder_hidden_states=encoder_states,
                past_key_values=kv_cache,
                cache_position=cache_pos,
                use_cache=True,
            )

    # actual runs
    start = [torch.cuda.Event(enable_timing=True) for _ in range(REPEATS)]
    end = [torch.cuda.Event(enable_timing=True) for _ in range(REPEATS)]

    torch.cuda.synchronize()

    for i in range(REPEATS):
        kv_cache.reset()
        start[i].record()

        with torch.no_grad():
            decoder(
                input_ids=input_ids,
                encoder_hidden_states=encoder_states,
                past_key_values=kv_cache,
                cache_position=cache_pos,
                use_cache=True,
            )

        end[i].record()

    torch.cuda.synchronize()
    times = [s.elapsed_time(e) for s, e in zip(start, end)]
    avg = float(np.mean(times))

    print(f"[{tag}]  batch={batch_size:2d}  cache_cap={cache_cap:2d}  latency={avg:.2f} ms")
    return avg



model = load_model()

results = []

for bs in TEST_BATCHES:
    base = run_benchmark(model, batch_size=bs, cache_cap=bs, tag="BASELINE")
    sliced = run_benchmark(model, batch_size=bs, cache_cap=MAX_BATCH_SIZE, tag="SLICED")
    diff = (sliced - base) / base * 100
    results.append((bs, base, sliced, diff))

print("Summary:")
print(f"{'Batch':<8} | {'Baseline (ms)':<15} | {'Sliced (ms)':<15} | Diff (%)")
for bs, base, sliced, diff in results:
    print(f"{bs:<8} | {base:<15.2f} | {sliced:<15.2f} | {diff:+.2f}%")

@i3hz
Contributor Author

i3hz commented Nov 28, 2025

@zucchini-nlp This doesn't have the max_batch_size you mentioned. If it's something I should add, please lmk.

Member

@zucchini-nlp zucchini-nlp left a comment

Thanks for a quick fix @i3hz!

I think we need to allow a max_batch_size which, if available, takes precedence when lazily initializing the cache. Early cache initialization is currently used only in export, but we can allow users to re-use the cache across several generations with a max batch size. It would also require us to change a few places in generation imo.
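Roughly what I have in mind, just as a sketch (the names here are illustrative, not the final API):

import torch

# Sketch only: an optional max_batch_size that, when set, takes precedence over the
# runtime batch dimension during lazy initialization (key_states.shape[0] stays the fallback).
class StaticLayerSketch:
    def __init__(self, max_cache_len: int, max_batch_size: int | None = None):
        self.max_cache_len = max_cache_len
        self.max_batch_size = max_batch_size

    def lazy_initialization(self, key_states: torch.Tensor):
        if self.max_batch_size is None:
            # No preset capacity: behave as today and size the cache to the current batch.
            self.max_batch_size = key_states.shape[0]
        _, num_heads, _, head_dim = key_states.shape
        self.keys = torch.zeros(
            (self.max_batch_size, num_heads, self.max_cache_len, head_dim),
            dtype=key_states.dtype,
            device=key_states.device,
        )
        self.values = torch.zeros_like(self.keys)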

After that, we can verify that there are no unwanted graph breaks and run the bench

LMK if this makes sense and you need guidance

@i3hz
Contributor Author

i3hz commented Nov 28, 2025

I think we need to allow a max_batch_size which, if available, takes precedence when lazily initializing the cache. Early cache initialization is currently used only in export, but we can allow users to re-use the cache across several generations with a max batch size. It would also require us to change a few places in generation imo.

So basically I should add max_batch_size to the __init__ method of StaticCache and then modify lazy_initialization in StaticLayer to use max_batch_size.

And also change line 1837 of src/transformers/generation/utils.py, as seen in #37394, to:

or cache_to_check.max_batch_size < batch_size
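In other words (a hypothetical helper just to illustrate the relaxed condition, not the actual code in utils.py):

# Hypothetical helper: with the change, the static cache is only rebuilt when its
# preallocated capacity is too small, instead of whenever the batch size differs.
def can_reuse_static_cache(cache_max_batch_size: int, batch_size: int) -> bool:
    return cache_max_batch_size >= batch_size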

@zucchini-nlp
Member

Yep, and a small test as well

@i3hz
Contributor Author

i3hz commented Nov 29, 2025

Hi @zucchini-nlp, I'm still stuck on this. I've been testing with torch.compile and it works fine with GPT-2, but it does not work with Whisper small; I'm not sure what I'm missing tbh.
If you have any pointers on what I should check or tweak, I'd really appreciate it.
Thanks a lot, and sorry for the trouble.

k_out = self.keys
v_out = self.values
batch_size = key_states.shape[0]
if k_out.shape[0] != batch_size:
Contributor

@mobicham mobicham Nov 29, 2025

I guess k_out.shape[0] >= batch_size is better

Contributor

When debugging the torch.compile stuff, can you check this:

assert k_out.data_ptr() == k_out[:batch_size].data_ptr() , "invalid k_out data copy()!"
assert v_out.data_ptr() == v_out[:batch_size].data_ptr() , "invalid v_out data copy()!"

If there's no copy, I don't see why Cudagraphs would break with Whisper.
What error do you get exactly btw?

Contributor Author

I guess k_out.shape[0] <= batch_size is better

Wait, should it be < or >, considering k_out will be larger than batch_size?

assert k_out.data_ptr() == k_out[:batch_size].data_ptr() , "invalid k_out data copy()!"

Builtin `operator.*` comparison with constant `self` failed
  Explanation: Failed to compare DataPtrVariable() with DataPtrVariable(), because DataPtrVariable() is not a Python constant or its mutation check fails.

About the actual torch.compile error I'm getting: I'm trying max_batch_size = 8 with the batch-size list [8, 4, 2, 1], and on batch size 4 it crashes with

Dynamo failed to run FX node with fake tensors: call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(s72, 6, 1, 64), dtype=torch.float16,
           grad_fn=<TransposeBackward0>), FakeTensor(..., device='cuda:0', size=(8, 6, 32, 64), dtype=torch.float16,
           grad_fn=<Error>), FakeTensor(..., device='cuda:0', size=(8, 6, 32, 64), dtype=torch.float16,
           grad_fn=<Error>)), **{'attn_mask': None, 'dropout_p': 0.0, 'scale': 1.0, 'is_causal': False}): got RuntimeError('Attempting to broadcast a dimension of length 8 at -2! Mismatching argument at index 1 had [8, 6]; but expected shape should be broadcastable to [s72, 6]')

from user code:
   File "/home/vedth/stuhdy/z.py", line 21, in decoder_forward
    out = model.model.decoder(
  File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 865, in forward
    layer_outputs = decoder_layer(
  File "/home/vedth/stuhdy/transformers/src/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/vedth/stuhdy/transformers/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/vedth/stuhdy/transformers/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 501, in forward
    hidden_states, cross_attn_weights = self.encoder_attn(
  File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 347, in forward
    attn_output, attn_weights = attention_interface(
  File "/home/vedth/stuhdy/transformers/src/transformers/integrations/sdpa_attention.py", line 92, in sdpa_attention_forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Contributor

Wait, should it be < or >, considering k_out will be larger than batch_size?

Oh sorry you're right, I meant current_batch_size <= max_batch_size

assert k_out.data_ptr() == k_out[:batch_size].data_ptr() , "invalid k_out data copy()!"

I meant run it without torch.compile, just to see if it performs any copy

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace

I see, I will try to debug next week too 👍

@mobicham
Contributor

@i3hz I will also do some debugging next week

@i3hz
Contributor Author

i3hz commented Nov 29, 2025

@i3hz I will also do some debugging next week

Thanks a lot.
The main issue still lies within torch.compile, as the model works without it.

@mobicham
Contributor

mobicham commented Dec 1, 2025

@i3hz I tried the slicing solution but it throws an attention error even without torch.compile:

RuntimeError: The size of tensor a (4) must match the size of tensor b (8) at non-singleton dimension 0

Something is strange, this works:

for bs in [8, 4, 2, 1]:
    past_key_values = create_cache(max_batch_size)
...

but when the cache is allocated only once, it throws that error:

past_key_values = create_cache(max_batch_size)
for bs in [8, 4, 2, 1]:
...
import torch
from transformers import WhisperForConditionalGeneration, AutoProcessor
from transformers.cache_utils import StaticCache, EncoderDecoderCache
import math
import numpy as np

device = 'cuda:0'
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3-turbo"

model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="sdpa", device_map=device) 
model.generation_config.cache_implementation = "static"

@torch.no_grad()
def run_encoder(model, labels, encoder_outputs, past_key_values, prefill: bool):

    seq_length = labels.shape[-1]
    if(prefill):
        cache_position = torch.arange(seq_length, device=device)
    else:
        cache_position = torch.tensor([seq_length], device=device)

    out_decoder = model.model.decoder(
        labels, 
        encoder_hidden_states=encoder_outputs,
        past_key_values = past_key_values,
        cache_position=cache_position,
        use_cache = True,
        return_dict=True,
    ) 

    cur_token = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(axis=-1) 
    past_key_values = out_decoder.past_key_values

    return cur_token, past_key_values

max_batch_size = 32
max_cache_len = 256
enc_len = 1500
decoder = model.model.decoder

################################################
from transformers import cache_utils
from typing import Optional, Any

def update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    if not self.is_initialized:
        self.lazy_initialization(key_states)

    cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
    cache_position = (
        cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
    )

    batch_size = key_states.shape[0]
    assert batch_size <= self.max_batch_size, f"Current batch-size {batch_size} should be <= max_batch_size ({self.max_batch_size})"

    print(f"{batch_size}:{self.max_batch_size}")

    k_out = self.keys[:batch_size]
    v_out = self.values[:batch_size]
    # Update the cache
    try:
        k_out.index_copy_(2, cache_position, key_states)
        v_out.index_copy_(2, cache_position, value_states)
    except NotImplementedError:
        # Fallback for devices like MPS where index_copy_ might not be supported.
        k_out[:, :, cache_position] = key_states
        v_out[:, :, cache_position] = value_states
    return k_out, v_out

cache_utils.StaticLayer.update = update

################################################
def create_cache(max_batch_size):
    # Cache
    self_cache = StaticCache(
        config=decoder.config,
        max_batch_size=max_batch_size,
        max_cache_len=max_cache_len,
        device=device,
        dtype=torch_dtype,
    )

    cross_cache = StaticCache(
        config=decoder.config,
        max_batch_size=max_batch_size,
        max_cache_len=enc_len,
        device=device,
        dtype=torch_dtype,
    )

    return EncoderDecoderCache(self_cache, cross_cache)

#torch._dynamo.config.capture_scalar_outputs = True
#run_encoder = torch.compile(run_encoder, mode='reduce-overhead', fullgraph=True)

max_batch_size = 8 
past_key_values = create_cache(max_batch_size)

for bs in [8, 4, 2, 1]:
    assert bs <= max_batch_size, "batch_size should be <= max_batch_size"
    seq_length = 3
    labels = torch.tensor([[50258, 50259, 50360]] * bs, device=device, dtype=torch.int64)
    
    encoder_outputs = torch.randn([bs, enc_len, 1280], device=device, dtype=torch_dtype)

    cur_token, past_key_values_out = run_encoder(model, labels, encoder_outputs, past_key_values, prefill=True)
    cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs.clone(), past_key_values_out, prefill=False)

@mobicham
Contributor

mobicham commented Dec 1, 2025

So the issue is this part:

key_states = past_key_values.layers[self.layer_idx].keys
value_states = past_key_values.layers[self.layer_idx].values

if you replace it with this, it works.

key_states = past_key_values.layers[self.layer_idx].keys[:bsz]
value_states = past_key_values.layers[self.layer_idx].values[:bsz]

However, the problem is that we can't do this for every modeling file separately. I guess the solution is to do something with self.keys and self.values, like this, it works with torch.compile:
@zucchini-nlp what do you think?

class StaticLayer(CacheLayerMixin):
    """
    A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len, head_dim]`.
    It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support.

    Args:
        max_cache_len (`int`):
            Maximum number of tokens that can be stored, used for tensor preallocation.
    """

    is_compileable = True
    is_sliding = False

    def __init__(self, max_cache_len: int):
        super().__init__()
        self.max_cache_len = max_cache_len

    def lazy_initialization(self, key_states: torch.Tensor):
        """
        Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
        num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
        devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).

        If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
        function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
        internally don't compile the prefill, this is guaranteed to have been called already when compiling.
        If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
        it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
        i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
        not be compiled anyway for performances!
        """
        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
        self.dtype, self.device = key_states.dtype, key_states.device

        self.keys_ = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        self.values_ = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )

        self.keys, self.values = self.keys_, self.values_

        # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
        # breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
        # As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
        # prefill explicitly, but this should be avoided!)
        if not is_torchdynamo_compiling():
            torch._dynamo.mark_static_address(self.keys_)
            torch._dynamo.mark_static_address(self.values_)

        self.is_initialized = True

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.
            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states)

        # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
        # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
        cache_position = (
            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
        )

        batch_size = key_states.shape[0]
        assert batch_size <= self.max_batch_size, f"Current batch-size {batch_size} should be <= max_batch_size ({self.max_batch_size})"

        self.keys = self.keys_[:batch_size]
        self.values = self.values_[:batch_size]

        # Update the cache
        try:
            self.keys.index_copy_(2, cache_position, key_states)
            self.values.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # Fallback for devices like MPS where index_copy_ might not be supported.
            self.keys[:, :, cache_position] = key_states
            self.values[:, :, cache_position] = value_states
        return self.keys, self.values

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        kv_offset = 0
        kv_length = self.max_cache_len
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        return (self.keys[0, 0].any(dim=-1)).sum() if self.is_initialized else 0

    def get_max_cache_shape(self) -> int:
        """Return the maximum cache shape of the cache"""
        return self.max_cache_len

@i3hz
Contributor Author

i3hz commented Dec 2, 2025

@i3hz I tried the slicing solution but it throws an attention error even without torch.compile:

You are right, it's not working :c
I switched models to gpt2 and it does work (as I mentioned before). I really don't know why that's happening; is it because gpt2 does not use EncoderDecoderCache?
Or I think the problem is probably that in the whisper testing script we only ran the encoding logic, whereas now we're also trying the decoding logic (which is an oversight on my part, sorry), so your suggestion about self.keys_ and self.values_ might be the correct fix.

My reproduction script, which uses gpt2 instead of whisper, in case you need it (it does run successfully):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache

device = "cuda"
model_id = "openai-community/gpt2" 

model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    dtype=torch.float16, 
    attn_implementation="sdpa",
    ignore_mismatched_sizes=True
).to(device)
model.eval()

def decode_step(model, input_ids, past_key_values, cache_position):
    out = model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        cache_position=cache_position,
        use_cache=True,
    )
    logits = out.logits
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    return next_token, out.past_key_values

compiled_decode = torch.compile(decode_step, mode="reduce-overhead", fullgraph=True)

max_batch_size = 8
max_seq_len = 64
dtype = torch.float16

past_key_values = StaticCache(
    config=model.config, 
    max_batch_size=max_batch_size, 
    max_cache_len=max_seq_len, 
    device=device, 
    dtype=dtype
)

batch_sizes = [8, 4, 2, 1]

try:
    for bs in batch_sizes:
        print(f"Batch Size: {bs}")
        past_key_values.reset()
        
        seq_len = 3
        input_ids = torch.randint(0, 1000, (bs, seq_len), device=device)
        cache_position = torch.arange(seq_len, device=device)
        
        with torch.no_grad():
            out = model(input_ids, past_key_values=past_key_values, cache_position=cache_position)
            cur_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

        cache_position = torch.tensor([seq_len], device=device)
        
        cur_token, _ = compiled_decode(model, cur_token, past_key_values, cache_position)
        


    print("Success")
except Exception as e:
    print(f"Failed on {bs} with error {e}")

@mobicham
Contributor

mobicham commented Dec 2, 2025

@i3hz yeah, because the issue is that at some point it returns self.keys and self.values, not just for Whisper but also for other models. The self.keys_ / self.values_ trick works; I think we just need to update the reset() function so that it updates self.keys_ / self.values_ instead.
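Something like this inside the layer (a sketch assuming the keys_ / values_ backing tensors from the snippet above, not the final implementation):

def reset(self):
    if self.is_initialized:
        # Zero the full backing tensors, not just the current batch-sized views.
        self.keys_.zero_()
        self.values_.zero_()
        # Point the views back at the full tensors; the next update() re-slices them.
        self.keys = self.keys_
        self.values = self.values_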

@i3hz
Contributor Author

i3hz commented Dec 3, 2025

I've implemented the self.keys_ and self.values_ functionality.
Along with that, I also had to override update, reset, and __len__ for StaticCache (to trigger updates for cross-attention).
In StaticLayer I've overridden the reset method as well (to correctly reset the cache).
And I've also added a max_batch_size parameter to StaticCache and StaticLayer.

So the testing script from earlier does work. But torch.compile still fails with a segmentation fault, which I'm working on.
Is this the expected fix @zucchini-nlp @mobicham?

(misclicked and accidentally closed the pr mb)

@i3hz i3hz closed this Dec 3, 2025
@i3hz i3hz reopened this Dec 3, 2025
@i3hz
Contributor Author

i3hz commented Dec 3, 2025

class StaticLayer(CacheLayerMixin):
    """
    A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len, head_dim]`.
    It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support.

    Args:
        max_cache_len (`int`):
            Maximum number of tokens that can be stored, used for tensor preallocation.
    """

    is_compileable = True
    is_sliding = False

    def __init__(self, max_cache_len: int, max_batch_size: int | None = None):
        super().__init__()
        self.max_cache_len = max_cache_len
        self.max_batch_size = max_batch_size

    def lazy_initialization(self, key_states: torch.Tensor):
        """
        Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
        num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
        devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).

        If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
        function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
        internally don't compile the prefill, this is guaranteed to have been called already when compiling.
        If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
        it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
        i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
        not be compiled anyway for performances!
        """
        if self.max_batch_size is None:
            self.max_batch_size = key_states.shape[0]
        _, self.num_heads, _, self.head_dim = key_states.shape
        self.dtype, self.device = key_states.dtype, key_states.device

        self.keys_ = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        self.values_ = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        self.keys = self.keys_
        self.values = self.values_
        # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
        # breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
        # As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
        # prefill explicitly, but this should be avoided!)
        if not is_torchdynamo_compiling():
            torch._dynamo.mark_static_address(self.keys_)
            torch._dynamo.mark_static_address(self.values_)

        self.is_initialized = True

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.
            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states)

        # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
        # in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
        cache_position = (
            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
        )
        batch_size = key_states.shape[0]
        # 3. Dynamic Slicing: Update the view to match current batch
        self.keys = self.keys_[:batch_size]
        self.values = self.values_[:batch_size]
        try:
            self.keys.index_copy_(2, cache_position, key_states)
            self.values.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            self.keys[:, :, cache_position] = key_states
            self.values[:, :, cache_position] = value_states

        return self.keys, self.values

    def reset(self):
        if self.is_initialized:
            self.keys_.zero_()
            self.values_.zero_()
            self.keys = self.keys_
            self.values = self.values_

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        kv_offset = 0
        kv_length = self.max_cache_len
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        return (self.keys[0, 0].any(dim=-1)).sum() if self.is_initialized else 0

    def get_max_cache_shape(self) -> int:
        """Return the maximum cache shape of the cache"""
        return self.max_cache_len

This is my StaticLayer class.

@zucchini-nlp
Member

However, the problem is that we can't do this for every modeling file separately. I guess the solution is to do something with self.keys and self.values, like this, it works with torch.compile:

Ah right, encoder-decoder ones are a bit different. Naming the kv with max batch size differently sounds good to me. Probably a bit more informative name would be better, it's easy to lose track when reading the code

@i3hz can you push the code you have with all the updates. Also, in which cases you're getting a seg fault, in test files or in bench script? It is important to not compile a prefill stage in custom generation loop, or if we have to compile in advance then cache has to be early initialized. The lazy init function is known to fail when compiled

@i3hz
Contributor Author

i3hz commented Dec 3, 2025

The code is a bit messy, but I'll clean it up later, sorry.

@mobicham
Contributor

mobicham commented Dec 4, 2025

Probably a bit more informative name would be better, it's easy to lose track when reading the code

Yeah, probably self.keys_, self.values_ -> self.keys_full, self.values_full, or something like that.

But torch.compile still fails with a segmentation fault, which I'm working on.
@i3hz do you still have this issue? torch.compile works fine with the self.keys_, self.values_ trick, at least with Whisper, are other models not working too?

@i3hz
Contributor Author

i3hz commented Dec 4, 2025

@i3hz do you still have this issue? torch.compile works fine with the self.keys_, self.values_ trick, at least with Whisper, are other models not working too?

It failed on my end. Can you please look into the implementation and lmk if I messed something up?

@i3hz
Contributor Author

i3hz commented Dec 5, 2025

Also, in which cases you're getting a seg fault, in test files or in bench script? It is important to not compile a prefill stage in custom generation loop, or if we have to compile in advance then cache has to be early initialized. The lazy init function is known to fail when compiled

Yeah, I forgot about this, sorry. I split the test into an eager prefill + compiled decode and it now passes successfully.

But it is failing the CI tests, which is something I'm working on.

@i3hz i3hz requested review from mobicham and zucchini-nlp December 5, 2025 13:48
@i3hz
Contributor Author

i3hz commented Dec 5, 2025

I would really appreciate any pointers on why the CI tests are failing @zucchini-nlp

@zucchini-nlp
Member

zucchini-nlp commented Dec 8, 2025

The failures on qwen3 aren't related; that probably happened due to too many requests, which caused a failure to download image data. I will take a look at this PR this week, I have been quite busy with other tasks lately.

In the meantime, can you make sure the PR is ready and has the final bench script and performance results in the description. Also, I think it's nice to allow users to initialize a cache with a max batch size from the model.generate() call, so we can pass the param over to the cache init when generating. That way we can also test whether the feature aligns well with auto-compile in the generation loop.
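Something along these lines from the user side (a sketch only; the "max_batch_size" key in cache_config is the new bit and purely illustrative here, not an existing parameter):

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2").to("cuda")
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
inputs = tokenizer(["hello world"], return_tensors="pt").to("cuda")

# Hypothetical: preallocate the static cache for up to 8 sequences and re-use it
# across generate() calls with smaller batch sizes.
out = model.generate(
    **inputs,
    cache_implementation="static",
    cache_config={"max_batch_size": 8, "max_cache_len": 128},  # "max_batch_size" is hypothetical
    max_new_tokens=16,
)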

@mobicham
Contributor

mobicham commented Dec 8, 2025

Btw, can you also do a speed benchmark? Based on the logs, it seems this creates an issue with cudagraphs, though I am not sure if this is critical:

shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim)) and attn_output = torch.nn.functional.scaled_dot_product_attention(  # transformers/integrations/sdpa_attention.py:96 in sdpa_attention_forward (_dynamo/utils.py:3421 in run_node)
V1208 17:05:18.623000 743734 torch/_dynamo/guards.py:3508] [0/2] [__recompiles]     - 0/0: tensor 

'fn.__self__.past_key_values.cross_attention_cache.layers[0].keys' size mismatch at index 0. expected 128, actual 1

I1208 17:05:21.047000 743734 torch/_inductor/cudagraph_trees.py:390] [__cudagraphs] recording cudagraph tree for graph without symints
V1208 17:05:21.048000 743734 torch/_inductor/cudagraph_trees.py:2256] [__cudagraphs] Running warmup of function 9
V1208 17:05:21.055000 743734 torch/_inductor/cudagraph_trees.py:2213] [__cudagraphs] Recording function 9 of graph recording id 9

@mobicham
Contributor

mobicham commented Dec 8, 2025

Btw, can you also do a speed benchmark? Based on the logs, it seems this creates an issue with cudagraphs, though I am not sure if this is critical

@i3hz There's performance regression with this change using torch.compile with batch_size > 16:

batch_size = 64

no-compile: 200 tokens/sec
compile + old static cache: 240 tokens/sec
compile + new static cache: 97 tokens/sec

you can reproduce it:

import torch
from transformers import WhisperForConditionalGeneration, AutoProcessor
from transformers.cache_utils import StaticCache, EncoderDecoderCache
import math
from tqdm import tqdm
import numpy as np

device = 'cuda:0'
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3-turbo"

model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="sdpa", device_map=device) 
model.generation_config.cache_implementation = "static"

@torch.no_grad()
def run_encoder_raw(model, labels, encoder_outputs, past_key_values, prefill: bool):

    seq_length = labels.shape[-1]
    if(prefill):
        cache_position = torch.arange(seq_length, device=device)
    else:
        cache_position = torch.tensor([seq_length], device=device)

    out_decoder = model.model.decoder(
        labels, 
        encoder_hidden_states=encoder_outputs,
        past_key_values = past_key_values,
        cache_position=cache_position,
        use_cache = True,
        return_dict=True,
    ) 

    cur_token = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(axis=-1) 
    past_key_values = out_decoder.past_key_values

    return cur_token, past_key_values

max_batch_size = 32
max_cache_len = 256
enc_len = 1500
decoder = model.model.decoder

def create_cache(max_batch_size):
    # Cache
    self_cache = StaticCache(
        config=decoder.config,
        max_batch_size=max_batch_size,
        max_cache_len=max_cache_len,
        device=device,
        dtype=torch_dtype,
    )

    cross_cache = StaticCache(
        config=decoder.config,
        max_batch_size=max_batch_size,
        max_cache_len=enc_len,
        device=device,
        dtype=torch_dtype,
    )

    return EncoderDecoderCache(self_cache, cross_cache)

max_batch_size = 64
past_key_values = create_cache(max_batch_size)

run_encoder = run_encoder_raw

for _ in range(2):
    for current_bs in [max_batch_size]: #[8, 4, 2, 1]:
        past_key_values.reset()
        assert current_bs <= max_batch_size, "batch_size should be <= max_batch_size"
        seq_length = 3
        labels = torch.tensor([[50258, 50259, 50360]] * current_bs, device=device, dtype=torch.int64)
        
        encoder_outputs = torch.randn([current_bs, enc_len, 1280], device=device, dtype=torch_dtype)

        cur_token, past_key_values_out = run_encoder(model, labels, encoder_outputs, past_key_values, prefill=True)

        # warm-up
        for _ in range(3):
            cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)
        torch.cuda.synchronize()

        for _ in tqdm(range(500)):
            cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)     
            torch.cuda.synchronize()

        print(f"{current_bs} pass!")


    run_encoder = torch.compile(run_encoder_raw, mode='reduce-overhead', fullgraph=True)
    print('----- compiled run ---- ')

@i3hz
Contributor Author

i3hz commented Dec 8, 2025

@i3hz There's performance regression with this change using torch.compile with batch_size > 16:

Yeah, you're right, it's most likely from the slicing operations. Is there any way to optimize it?

@mobicham
Contributor

mobicham commented Dec 8, 2025

@i3hz it's weird, because that slicing is not doing any copy; I also tried .narrow() instead, same behavior. It might be related to torch.compile being used with self.keys_full / self.values_full instead of self.keys / self.values 🤔
because when I do the slicing with self.keys / self.values instead of the _full tensors, perf is unchanged, but then again we have the issue of what gets returned in .keys / .values during decoding.
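Quick sanity check that the slice really is a copy-free view (same idea as the data_ptr() asserts earlier in the thread):

import torch

full = torch.zeros(64, 6, 256, 64)
sliced = full[:4]
narrowed = full.narrow(0, 0, 4)

# Both share storage with the full tensor, so no copy is involved.
print(full.data_ptr() == sliced.data_ptr())    # True
print(full.data_ptr() == narrowed.data_ptr())  # True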

My assumption is that the issue comes from batch_size in update, which is dynamic and introduces SymInts that mess with torch._dynamo.mark_static_address().

It seems to me there's no clean way of doing this properly. One thing that could work is initializing the full cache outside and passing it to lazy_initialization() for each new cache created with a given batch_size; I can try that tomorrow.

@i3hz
Contributor Author

i3hz commented Dec 9, 2025

One thing we can do is: if the current batch size is equal to the max batch size, we just return the whole thing without slicing. But then again, it's not really a fix for the performance reduction, although it does give us similar performance for your repro script @mobicham.

The performance reduction is just related to the tensor slicing. I tried a basic script comparing two compiled functions, one just returning the tensor and the other slicing it, and over 10k iterations the sliced one is more than 2x slower.
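A minimal sketch of the kind of comparison I mean (not my exact script; numbers will vary by GPU and torch version):

import time
import torch

x = torch.randn(64, 6, 256, 64, device="cuda")

@torch.compile
def return_full(t):
    return t

@torch.compile
def return_sliced(t):
    return t[:32]

def bench(fn, iters=10_000):
    # Warmup (also triggers compilation), then time the loop.
    for _ in range(10):
        fn(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return time.perf_counter() - start

print("return full  :", bench(return_full))
print("return sliced:", bench(return_sliced))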

So what can we do?

@mobicham
Contributor

mobicham commented Dec 9, 2025

@i3hz yeah I am not sure, I asked on the GPU MODE Discord, let's see what the Torch folks say about this slicing op

@mobicham
Contributor

mobicham commented Dec 9, 2025

@i3hz I made it work, no perf regression and vram usage in check! Feel free to add the other checks and the comment.
Btw, I see that StaticCache does not pass max_batch_size to StaticLayer 🤔 :

layer = StaticLayer(max_cache_len=max_cache_len)

right now I hard-coded it to 128 for testing, but we need to get it from kwargs:

class StaticLayer(CacheLayerMixin):
    is_compileable = True
    is_sliding = False

    def __init__(self, max_cache_len: int):
        super().__init__()
        self.max_cache_len = max_cache_len
        self.max_batch_size = 128

    def init_full_cache(self, key_states: torch.Tensor):
        """ Init the full batch with max_batch_size  """
        _, self.num_heads, _, self.head_dim = key_states.shape
        self.dtype, self.device = key_states.dtype, key_states.device

        self.keys_full = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        self.values_full = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        
    def lazy_initialization(self, key_states: torch.Tensor):
        """ Mark each chunk with torch static address """
        self.init_full_cache(key_states)
        
        self.keys_map = [None] * (self.max_batch_size + 1)
        self.values_map = [None] * (self.max_batch_size + 1)
        for batch_size in range(1, self.max_batch_size + 1):
            self.keys_map[batch_size] = self.keys_full[:batch_size]
            self.values_map[batch_size] = self.values_full[:batch_size]
        
            if not is_torchdynamo_compiling():
                torch._dynamo.mark_static_address(self.keys_map[batch_size])
                torch._dynamo.mark_static_address(self.values_map[batch_size])

        self.is_initialized = True

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        
        if not self.is_initialized:
            self.lazy_initialization(key_states)

        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
        cache_position = (
            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
        )
        

        batch_size = key_states.shape[0]        
        self.keys = self.keys_map[batch_size]
        self.values = self.values_map[batch_size]
        self._current_batch_size = batch_size

        try:
            self.keys.index_copy_(2, cache_position, key_states)
            self.values.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            self.keys[:, :, cache_position] = key_states
            self.values[:, :, cache_position] = value_states
        return self.keys, self.values

    def reset(self):
        if self.is_initialized:
            self.keys_full.zero_()
            self.values_full.zero_()
            self._current_batch_size = None

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        kv_offset = 0
        kv_length = self.max_cache_len
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        return (self.keys_full[0, 0].any(dim=-1)).sum() if self.is_initialized else 0

    def get_max_cache_shape(self) -> int:
        return self.max_cache_len
    

@i3hz
Contributor Author

i3hz commented Dec 10, 2025

Btw, I see that StaticCache does not pass max_batch_size to StaticLayer 🤔

I've passed it in this branch :)

I made it work, no perf regression and vram usage in check

Yeah, LGTM. I think it will add some latency during the first forward pass for massive batch sizes, so I'd just like to confirm with @zucchini-nlp before pushing the changes.

@mobicham
Contributor

@i3hz I didn't see any particular increase in latency at cache creation; the extra work is the loop, and that loop is not doing any copy. But I do agree we need confirmation from @zucchini-nlp.

@zucchini-nlp
Member

Hey, sorry again for the delay. Interesting that the slicing op is causing issues; I was expecting the sliced keys to share the same static address.

The above workaround doesn't look bad to me, since we're still creating a single cache object. I wonder if marking it static across different batch sizes affects anything 🤔

@i3hz
Contributor Author

i3hz commented Dec 13, 2025

self = <tests.utils.test_cache_utils.CacheExportIntegrationTest testMethod=test_hybrid_cache_exportability>

    @pytest.mark.torch_export_test
    def test_hybrid_cache_exportability(self):
        """
        Tests that static cache works with `torch.export()`
        """
        if not is_torch_greater_or_equal("2.6"):
            self.skipTest(reason="This test requires torch >= 2.6 to run.")

        from transformers.integrations.executorch import TorchExportableModuleForDecoderOnlyLM

        set_seed(0)
        model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
        model = AutoModelForCausalLM.from_pretrained(model_id)
        model.eval()
        self.assertEqual(model.config.use_cache, True)

        # Export + hybrid StaticCache
        model.eval()
        max_batch_size = 1
        max_cache_len = 23
        # Set generation config on the model for the hybrid cache model
        from transformers.generation.configuration_utils import GenerationConfig

        model.generation_config = GenerationConfig(
            use_cache=True,
            cache_implementation="static",
            max_length=max_cache_len,
            cache_config={
                "batch_size": max_batch_size,
                "max_cache_len": max_cache_len,
                "device": model.device,
            },
        )
        exportable_module = TorchExportableModuleForDecoderOnlyLM(model)
        exported_program = exportable_module.export(
            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
        )
        n_g_key_caches = n_g_value_caches = 0
        for buffer_name, buffer in exported_program.named_buffers():
            if buffer_name.startswith("key_cache"):
                self.assertTrue(buffer.shape[0] == max_batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_g_key_caches = n_g_key_caches + 1
            if buffer_name.startswith("value_cache"):
                self.assertTrue(buffer.shape[0] == max_batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_g_value_caches = n_g_value_caches + 1
>       self.assertEqual(n_g_key_caches, model.config.num_hidden_layers)
E       AssertionError: 0 != 2

tests/utils/test_cache_utils.py:808: AssertionError

Why is this test failing? Is it because we're calling it keys_full instead of keys?

Anyway, that's the only test that's failing.
@mobicham Can you please confirm that this aligns with your fix? Thank you

@mobicham
Contributor

@i3hz sorry for the late answer. I haven't run the test but I guess buffer.shape[0] == max_batch_size is the part that is failing? I will try to get back to this over the weekend 🙏

@i3hz
Contributor Author

i3hz commented Dec 18, 2025

@i3hz sorry for the late answer. I haven't run the test but I guess buffer.shape[0] == max_batch_size is the part that is failing? I will try to get back to this over the weekend 🙏

Yep, that's the one. All good, take your time (I'm also on break lol).

@i3hz
Contributor Author

i3hz commented Jan 5, 2026

@zucchini-nlp should i close this PR?

@mobicham
Contributor

mobicham commented Jan 5, 2026

Apologies for the delay, I just got back from a break.

The issue in the test is that, by design, the cache now needs to have max_batch_size as dim 0.
What we could do to relax the test is to use self.assertTrue(buffer.shape[0] <= max_batch_size).

What do you think?

@i3hz
Contributor Author

i3hz commented Jan 10, 2026

What we could do to relax the test is to use self.assertTrue(buffer.shape[0] <= max_batch_size).

That does sound reasonable, but I'd like to confirm with @zucchini-nlp before making that change.

Does it pass the test after the change?

@zucchini-nlp
Member

Yeah, ofc, feel free to adjust the tests as needed.

@i3hz
Contributor Author

i3hz commented Jan 13, 2026

        exported_program = exportable_module.export(
            input_ids=torch.tensor([[1]], dtype=torch.long, device=model.device),
            cache_position=torch.tensor([0], dtype=torch.long, device=model.device),
        )
        n_g_key_caches = n_g_value_caches = 0
        for buffer_name, buffer in exported_program.named_buffers():
            if buffer_name.startswith("key_cache"):
                self.assertTrue(buffer.shape[0] <= max_batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_g_key_caches = n_g_key_caches + 1
            if buffer_name.startswith("value_cache"):
                self.assertTrue(buffer.shape[0] <= max_batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_g_value_caches = n_g_value_caches + 1
>       self.assertEqual(n_g_key_caches, model.config.num_hidden_layers)
E       AssertionError: 0 != 2

tests/utils/test_cache_utils.py:808: AssertionError
------------------------------------------------------------------- Captured stderr call --------------------------------------------------------------------
Unrecognized keys in `rope_parameters` for 'rope_type'='default': {'sliding_attention', 'full_attention'}
use_kernel_func_from_hub is not available in the installed kernels version. Please upgrade kernels to use this feature.
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
Loading weights: 100%|██████████| 28/28 [00:00<00:00, 1197.21it/s, Materializing param=model.norm.weight]
--------------------------------------------------------------------- Captured log call ---------------------------------------------------------------------
WARNING  huggingface_hub.utils._http:_http.py:779 Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
===================================================================== warnings summary ======================================================================
<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyPacked has no __module__ attribute

<frozen importlib._bootstrap>:488
  <frozen importlib._bootstrap>:488: DeprecationWarning: builtin type SwigPyObject has no __module__ attribute

tests/utils/test_cache_utils.py::CacheExportIntegrationTest::test_hybrid_cache_exportability
  /home/vedth/stuhdy/transformers/venv/lib/python3.12/site-packages/torch/_dynamo/output_graph.py:1711: UserWarning: While exporting, we found certain side effects happened in the model.forward. Here are the list of potential sources you can double check: ["L['self'].cache.layers[0]", "L['self'].cache.layers[1]"]
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================== short test summary info ==================================================================
FAILED tests/utils/test_cache_utils.py::CacheExportIntegrationTest::test_hybrid_cache_exportability - AssertionError: 0 != 2
=============================================================== 1 failed, 3 warnings in 5.13s ===============================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute

@mobicham The test doesn't pass after adding the <=, or am I missing something?

@i3hz
Contributor Author

i3hz commented Jan 13, 2026

Another thing I found: on main
num_hidden_layers: 2
n_g_key_caches: 2

but on this branch
n_g_key_caches: 0
num_hidden_layers: 2

Why isn't the cache being registered?

@i3hz
Contributor Author

i3hz commented Jan 14, 2026

If we move the slicing into lazy_initialization instead of update, it fixes the issue without having to change the test:

def lazy_initialization(self, key_states: torch.Tensor):
        """
        Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
        num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
        devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).

        If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
        function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
        internally don't compile the prefill, this is guaranteed to have been called already when compiling.
        If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
        it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
        i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
        not be compiled anyway for performances!
        """

        self.init_full_cache(key_states)

        self.keys_map = [None] * (self.max_batch_size + 1)
        self.values_map = [None] * (self.max_batch_size + 1)
        for batch_size in range(1, self.max_batch_size + 1):
            self.keys_map[batch_size] = self.keys_full[:batch_size]
            self.values_map[batch_size] = self.values_full[:batch_size]
            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
            # breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
            # As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
            # prefill explicitly, but this should be avoided!)
            if not is_torchdynamo_compiling():
                torch._dynamo.mark_static_address(self.keys_map[batch_size])
                torch._dynamo.mark_static_address(self.values_map[batch_size])
        self.is_initialized = True
        batch_size = key_states.shape[0]

        # Slice to current batch size
        self.keys = self.keys_map[batch_size]
        self.values = self.values_map[batch_size]

What do you think? I think this is due to TorchExportableModuleWithHybridCache trying to register the cache, but since it doesn't exist yet it just fails.

@i3hz
Contributor Author

i3hz commented Jan 14, 2026

import torch
from transformers import WhisperForConditionalGeneration, AutoProcessor
from transformers.cache_utils import StaticCache, EncoderDecoderCache
import math
from tqdm import tqdm
import numpy as np

device = 'cuda:0'
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3-turbo"

model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="sdpa", device_map=device) 
model.generation_config.cache_implementation = "static"

@torch.no_grad()
def run_encoder_raw(model, labels, encoder_outputs, past_key_values, prefill: bool):

    seq_length = labels.shape[-1]
    if(prefill):
        cache_position = torch.arange(seq_length, device=device)
    else:
        cache_position = torch.tensor([seq_length], device=device)

    out_decoder = model.model.decoder(
        labels, 
        encoder_hidden_states=encoder_outputs,
        past_key_values = past_key_values,
        cache_position=cache_position,
        use_cache = True,
        return_dict=True,
    ) 

    cur_token = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(axis=-1) 
    past_key_values = out_decoder.past_key_values

    return cur_token, past_key_values

max_batch_size = 32
max_cache_len = 256
enc_len = 1500
decoder = model.model.decoder

def create_cache(max_batch_size):
    # Cache
    self_cache = StaticCache(
        config=decoder.config,
        max_batch_size=max_batch_size,
        max_cache_len=max_cache_len,
        device=device,
        dtype=torch_dtype,
    )

    cross_cache = StaticCache(
        config=decoder.config,
        max_batch_size=max_batch_size,
        max_cache_len=enc_len,
        device=device,
        dtype=torch_dtype,
    )

    return EncoderDecoderCache(self_cache, cross_cache)

max_batch_size = 64
past_key_values = create_cache(max_batch_size)

run_encoder = run_encoder_raw

for _ in range(2):
    for current_bs in [max_batch_size]: #[8, 4, 2, 1]:
        past_key_values.reset()
        assert current_bs <= max_batch_size, "batch_size should be <= max_batch_size"
        seq_length = 3
        labels = torch.tensor([[50258, 50259, 50360]] * current_bs, device=device, dtype=torch.int64)
        
        encoder_outputs = torch.randn([current_bs, enc_len, 1280], device=device, dtype=torch_dtype)

        cur_token, past_key_values_out = run_encoder(model, labels, encoder_outputs, past_key_values, prefill=True)

        # warm-up
        for _ in range(3):
            cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)
        torch.cuda.synchronize()

        for _ in tqdm(range(500)):
            cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)     
            torch.cuda.synchronize()

        print(f"{current_bs} pass!")


    run_encoder = torch.compile(run_encoder_raw, mode='reduce-overhead', fullgraph=True)
    print('----- compiled run ---- ')

I don't see any big difference in performance after using this either.

@mobicham
Contributor

mobicham commented Jan 14, 2026

Thanks @i3hz, that's good news. I just tested it and it works 👍. I also don't see a difference in perf, so it should be good to go.

You should test with `for current_bs in [1, 2]` btw; that's where you'll see a big difference in decoding speed because it's memory-bound. For current_bs=1 it goes from ~266 tk/sec to ~1100 tk/sec with compile.
This is what I tested (via monkey patching):

import torch
from transformers import WhisperForConditionalGeneration, AutoProcessor
from transformers.cache_utils import StaticCache, EncoderDecoderCache
import math
from tqdm import tqdm
import numpy as np

device = 'cuda:0'
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3-turbo"

model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="sdpa", device_map=device) 
model.generation_config.cache_implementation = "static"

@torch.no_grad()
def run_encoder_raw(model, labels, encoder_outputs, past_key_values, prefill: bool):

    seq_length = labels.shape[-1]
    if(prefill):
        cache_position = torch.arange(seq_length, device=device)
    else:
        cache_position = torch.tensor([seq_length], device=device)

    out_decoder = model.model.decoder(
        labels, 
        encoder_hidden_states=encoder_outputs,
        past_key_values = past_key_values,
        cache_position=cache_position,
        use_cache = True,
        return_dict=True,
    ) 

    cur_token = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(axis=-1) 
    past_key_values = out_decoder.past_key_values

    return cur_token, past_key_values

max_batch_size = 32
max_cache_len = 256
enc_len = 1500
decoder = model.model.decoder

##################################################################################################
#Monkey patch
from typing import Optional, Any
from transformers import cache_utils
from transformers.cache_utils import is_torchdynamo_compiling, CacheLayerMixin

class StaticLayer(CacheLayerMixin):
    is_compileable = True
    is_sliding = False

    def __init__(self, max_cache_len: int):
        super().__init__()
        self.max_cache_len = max_cache_len
        self.max_batch_size = 128

    def init_full_cache(self, key_states: torch.Tensor):
        """ Init the full batch with max_batch_size  """
        _, self.num_heads, _, self.head_dim = key_states.shape
        self.dtype, self.device = key_states.dtype, key_states.device

        self.keys_full = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )
        self.values_full = torch.zeros(
            (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
            dtype=self.dtype,
            device=self.device,
        )

    def lazy_initialization(self, key_states: torch.Tensor):
        self.init_full_cache(key_states)

        self.keys_map = [None] * (self.max_batch_size + 1)
        self.values_map = [None] * (self.max_batch_size + 1)
        for batch_size in range(1, self.max_batch_size + 1):
            self.keys_map[batch_size] = self.keys_full[:batch_size]
            self.values_map[batch_size] = self.values_full[:batch_size]
            if not is_torchdynamo_compiling():
                torch._dynamo.mark_static_address(self.keys_map[batch_size])
                torch._dynamo.mark_static_address(self.values_map[batch_size])
        self.is_initialized = True
        batch_size = key_states.shape[0]

        # Slice to current batch size
        self.keys = self.keys_map[batch_size]
        self.values = self.values_map[batch_size]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        
        if not self.is_initialized:
            self.lazy_initialization(key_states)

        cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
        cache_position = (
            cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
        )
        

        batch_size = key_states.shape[0]        
        self.keys = self.keys_map[batch_size]
        self.values = self.values_map[batch_size]
        self._current_batch_size = batch_size

        try:
            self.keys.index_copy_(2, cache_position, key_states)
            self.values.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            self.keys[:, :, cache_position] = key_states
            self.values[:, :, cache_position] = value_states
        return self.keys, self.values

    def reset(self):
        if self.is_initialized:
            self.keys_full.zero_()
            self.values_full.zero_()
            self._current_batch_size = None

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        kv_offset = 0
        kv_length = self.max_cache_len
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        return (self.keys_full[0, 0].any(dim=-1)).sum() if self.is_initialized else 0

    def get_max_cache_shape(self) -> int:
        return self.max_cache_len

cache_utils.StaticLayer = StaticLayer
##################################################################################################

def create_cache(max_batch_size):
    # Cache
    self_cache = StaticCache(
        config=decoder.config,
        max_batch_size=max_batch_size,
        max_cache_len=max_cache_len,
        device=device,
        dtype=torch_dtype,
    )

    cross_cache = StaticCache(
        config=decoder.config,
        max_batch_size=max_batch_size,
        max_cache_len=enc_len,
        device=device,
        dtype=torch_dtype,
    )

    return EncoderDecoderCache(self_cache, cross_cache)

max_batch_size = 64
past_key_values = create_cache(max_batch_size)

run_encoder = run_encoder_raw

for _ in range(2):
    for current_bs in [1]: #[8, 4, 2, 1]:
        past_key_values.reset()
        assert current_bs <= max_batch_size, "batch_size should be <= max_batch_size"
        seq_length = 3
        labels = torch.tensor([[50258, 50259, 50360]] * current_bs, device=device, dtype=torch.int64)
        
        encoder_outputs = torch.randn([current_bs, enc_len, 1280], device=device, dtype=torch_dtype)

        cur_token, past_key_values_out = run_encoder(model, labels, encoder_outputs, past_key_values, prefill=True)

        # warm-up
        for _ in range(3):
            cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)
        torch.cuda.synchronize()

        for _ in tqdm(range(500)):
            cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)     
            torch.cuda.synchronize()

        print(f"{current_bs} pass!")


    run_encoder = torch.compile(run_encoder_raw, mode='reduce-overhead', fullgraph=True)
    print('----- compiled run ---- ')

Output:

`torch_dtype` is deprecated! Use `dtype` instead!
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:01<00:00, 266.25it/s]
1 pass!
----- compiled run ---- 
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 1124.07it/s]
1 pass!
----- compiled run ----

@i3hz
Contributor Author

i3hz commented Jan 15, 2026

Thanks a lot for your feedback @mobicham. What should I add as a test? Should it just be a trimmed-down version of the benchmarking script?
And what is the "PR slow CI" test? I didn't make any changes to the workflow.
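
Something like this rough sketch maybe? (The model, shapes, and test name below are just placeholders, not an existing test; the idea is to reuse one cache across shrinking batch sizes, which is exactly the scenario from the linked issue.)

import torch
from transformers import WhisperForConditionalGeneration
from transformers.cache_utils import StaticCache, EncoderDecoderCache


def test_static_cache_reused_across_batch_sizes():
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny", dtype=dtype, attn_implementation="sdpa"
    ).to(device)
    decoder = model.model.decoder

    max_batch_size, max_cache_len, enc_len = 4, 32, 64
    cache = EncoderDecoderCache(
        StaticCache(config=decoder.config, max_batch_size=max_batch_size,
                    max_cache_len=max_cache_len, device=device, dtype=dtype),
        StaticCache(config=decoder.config, max_batch_size=max_batch_size,
                    max_cache_len=enc_len, device=device, dtype=dtype),
    )

    # The same cache object is reused with batch sizes smaller than max_batch_size;
    # before the fix, reusing the cache like this crashed (see the linked issue).
    for batch_size in [4, 2, 1]:
        cache.reset()
        input_ids = torch.tensor([[50258, 50259, 50360]] * batch_size, device=device)
        encoder_states = torch.randn(batch_size, enc_len, model.config.d_model,
                                     device=device, dtype=dtype)
        with torch.no_grad():
            out = decoder(
                input_ids,
                encoder_hidden_states=encoder_states,
                past_key_values=cache,
                cache_position=torch.arange(input_ids.shape[-1], device=device),
                use_cache=True,
                return_dict=True,
            )
        assert out.last_hidden_state.shape[0] == batch_size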
