Fixes StaticCache Crashes #42467
base: main
Conversation
@zucchini-nlp This doesn't have the

zucchini-nlp left a comment
Thanks for the quick fix @i3hz!
I think we need to allow max_batch_size, which will take precedence if available when lazily initializing the cache. Early cache initialization is currently used only in export, but we can allow users to re-use a cache across several generations with max batch size. It would also require us to change a few places in generation imo
After that, we can verify that there are no unwanted graph breaks and run the bench
LMK if this makes sense and you need guidance
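Roughly something like this for the lazy init, I assume (a sketch only, not the final PR code; the getattr fallback and attribute names are my assumption):
import torch

def lazy_initialization(self, key_states: torch.Tensor):
    # Sketch: an explicitly provided max_batch_size takes precedence; otherwise
    # fall back to the batch dimension of the first key_states seen.
    runtime_batch_size, self.num_heads, _, self.head_dim = key_states.shape
    if getattr(self, "max_batch_size", None) is None:
        self.max_batch_size = runtime_batch_size
    self.dtype, self.device = key_states.dtype, key_states.device
    shape = (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim)
    self.keys = torch.zeros(shape, dtype=self.dtype, device=self.device)
    self.values = torch.zeros(shape, dtype=self.dtype, device=self.device)
    self.is_initialized = True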
So basically I should add And also change line 1837 from or cache_to_check.max_batch_size < batch_size

Yep, and a small test as well

Hi @zucchini-nlp I'm still stuck on this. I've been testing with
src/transformers/cache_utils.py (Outdated)
k_out = self.keys
v_out = self.values
batch_size = key_states.shape[0]
if k_out.shape[0] != batch_size:
I guess k_out.shape[0] >= batch_size is better
When debugging the torch.compile stuff, can you check this:
assert k_out.data_ptr() == k_out[:batch_size].data_ptr(), "invalid k_out data copy()!"
assert v_out.data_ptr() == v_out[:batch_size].data_ptr(), "invalid v_out data copy()!"
If there's no copy, I don't see why Cudagraphs would break with Whisper.
What error do you get exactly btw?
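For context, slicing the leading batch dimension gives a view over the same storage (no copy), which is what those asserts check. A tiny standalone illustration (shapes are arbitrary):
import torch

t = torch.zeros(8, 6, 32, 64)           # stands in for the preallocated k_out
view = t[:4]                             # slicing the batch dim makes a view, not a copy
assert view.data_ptr() == t.data_ptr()   # same starting address, so no data was copied
assert view._base is t                   # the slice is backed by the original tensor
print("leading-dim slice shares storage with the full tensor")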
I guess k_out.shape[0] <= batch_size is better
Wait, should it be < or >, considering k_out will be larger than the batch_size?
assert k_out.data_ptr() == k_out[:batch_size].data_ptr() , "invalid k_out data copy()!"
Builtin `operator.*` comparison with constant `self` failed
Explanation: Failed to compare DataPtrVariable() with DataPtrVariable(), because DataPtrVariable() is not a Python constant or its mutation check fails.
About the actual torch.compile error I'm getting: I'm trying max_batch_size = 8 with the batch-size list [8, 4, 2, 1].
On batch size 4 it crashes with:
Dynamo failed to run FX node with fake tensors: call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(s72, 6, 1, 64), dtype=torch.float16,
grad_fn=<TransposeBackward0>), FakeTensor(..., device='cuda:0', size=(8, 6, 32, 64), dtype=torch.float16,
grad_fn=<Error>), FakeTensor(..., device='cuda:0', size=(8, 6, 32, 64), dtype=torch.float16,
grad_fn=<Error>)), **{'attn_mask': None, 'dropout_p': 0.0, 'scale': 1.0, 'is_causal': False}): got RuntimeError('Attempting to broadcast a dimension of length 8 at -2! Mismatching argument at index 1 had [8, 6]; but expected shape should be broadcastable to [s72, 6]')
from user code:
File "/home/vedth/stuhdy/z.py", line 21, in decoder_forward
out = model.model.decoder(
File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 865, in forward
layer_outputs = decoder_layer(
File "/home/vedth/stuhdy/transformers/src/transformers/modeling_layers.py", line 94, in __call__
return super().__call__(*args, **kwargs)
File "/home/vedth/stuhdy/transformers/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/vedth/stuhdy/transformers/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
return forward_call(*args, **kwargs)
File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 501, in forward
hidden_states, cross_attn_weights = self.encoder_attn(
File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 347, in forward
attn_output, attn_weights = attention_interface(
File "/home/vedth/stuhdy/transformers/src/transformers/integrations/sdpa_attention.py", line 92, in sdpa_attention_forward
attn_output = torch.nn.functional.scaled_dot_product_attention(
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Wait should it be < or > considering k_out will be larger than the batch_size
Oh sorry you're right, I meant current_batch_size <= max_batch_size
assert k_out.data_ptr() == k_out[:batch_size].data_ptr() , "invalid k_out data copy()!"
I meant run it without torch.compile, just to see if it performs any copy
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace
I see, I will try to debug next week too 👍
@i3hz I will also do some debugging next week

Thanks a lot
@i3hz I tried the slicing solution but it throws an attention error even without torch.compile. Something is strange, this works:

for bs in [8, 4, 2, 1]:
    past_key_values = create_cache(max_batch_size)
    ...

...but when the cache is allocated only once, it throws that error:

past_key_values = create_cache(max_batch_size)
for bs in [8, 4, 2, 1]:
    ...

import torch
from transformers import WhisperForConditionalGeneration, AutoProcessor
from transformers.cache_utils import StaticCache, EncoderDecoderCache
import math
import numpy as np
device = 'cuda:0'
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3-turbo"
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="sdpa", device_map=device)
model.generation_config.cache_implementation = "static"
@torch.no_grad()
def run_encoder(model, labels, encoder_outputs, past_key_values, prefill: bool):
seq_length = labels.shape[-1]
if(prefill):
cache_position = torch.arange(seq_length, device=device)
else:
cache_position = torch.tensor([seq_length], device=device)
out_decoder = model.model.decoder(
labels,
encoder_hidden_states=encoder_outputs,
past_key_values = past_key_values,
cache_position=cache_position,
use_cache = True,
return_dict=True,
)
cur_token = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(axis=-1)
past_key_values = out_decoder.past_key_values
return cur_token, past_key_values
max_batch_size = 32
max_cache_len = 256
enc_len = 1500
decoder = model.model.decoder
################################################
from transformers import cache_utils
from typing import Optional, Any
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.is_initialized:
self.lazy_initialization(key_states)
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
cache_position = (
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
)
batch_size = key_states.shape[0]
assert batch_size <= self.max_batch_size, f"Current batch-size {batch_size} should be <= max_batch_size ({self.max_batch_size})"
print(f"{batch_size}:{self.max_batch_size}")
k_out = self.keys[:batch_size]
v_out = self.values[:batch_size]
# Update the cache
try:
k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# Fallback for devices like MPS where index_copy_ might not be supported.
k_out[:, :, cache_position] = key_states
v_out[:, :, cache_position] = value_states
return k_out, v_out
cache_utils.StaticLayer.update = update
################################################
def create_cache(max_batch_size):
# Cache
self_cache = StaticCache(
config=decoder.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=torch_dtype,
)
cross_cache = StaticCache(
config=decoder.config,
max_batch_size=max_batch_size,
max_cache_len=enc_len,
device=device,
dtype=torch_dtype,
)
return EncoderDecoderCache(self_cache, cross_cache)
#torch._dynamo.config.capture_scalar_outputs = True
#run_encoder = torch.compile(run_encoder, mode='reduce-overhead', fullgraph=True)
max_batch_size = 8
past_key_values = create_cache(max_batch_size)
for bs in [8, 4, 2, 1]:
assert bs <= max_batch_size, "batch_size should be <= max_batch_size"
seq_length = 3
labels = torch.tensor([[50258, 50259, 50360]] * bs, device=device, dtype=torch.int64)
encoder_outputs = torch.randn([bs, enc_len, 1280], device=device, dtype=torch_dtype)
cur_token, past_key_values_out = run_encoder(model, labels, encoder_outputs, past_key_values, prefill=True)
cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs.clone(), past_key_values_out, prefill=False)
So the issue is this part: transformers/src/transformers/models/whisper/modeling_whisper.py, lines 329 to 330 (at 7f5c209).
If you replace it with this, it works:

key_states = past_key_values.layers[self.layer_idx].keys[:bsz]
value_states = past_key_values.layers[self.layer_idx].values[:bsz]

However, the problem is that we can't do this for every modeling file separately. I guess the solution is to do something with StaticLayer:

class StaticLayer(CacheLayerMixin):
"""
A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len, head_dim]`.
It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support.
Args:
max_cache_len (`int`):
Maximum number of tokens that can be stored, used for tensor preallocation.
"""
is_compileable = True
is_sliding = False
def __init__(self, max_cache_len: int):
super().__init__()
self.max_cache_len = max_cache_len
def lazy_initialization(self, key_states: torch.Tensor):
"""
Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).
If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
internally don't compile the prefill, this is guaranteed to have been called already when compiling.
If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
not be compiled anyway for performances!
"""
self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
self.dtype, self.device = key_states.dtype, key_states.device
self.keys_ = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.values_ = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.keys, self.values = self.keys_, self.values_
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
# breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
# As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
# prefill explicitly, but this should be avoided!)
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(self.keys_)
torch._dynamo.mark_static_address(self.values_)
self.is_initialized = True
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Update the key and value caches in-place, and return the necessary keys and value states.
Args:
key_states (`torch.Tensor`): The new key states to cache.
value_states (`torch.Tensor`): The new value states to cache.
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
Returns:
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
"""
# Lazy initialization
if not self.is_initialized:
self.lazy_initialization(key_states)
# Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
# in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
cache_position = (
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
)
batch_size = key_states.shape[0]
assert batch_size <= self.max_batch_size, f"Current batch-size {batch_size} should be <= max_batch_size ({self.max_batch_size})"
self.keys = self.keys_[:batch_size]
self.values = self.values_[:batch_size]
# Update the cache
try:
self.keys.index_copy_(2, cache_position, key_states)
self.values.index_copy_(2, cache_position, value_states)
except NotImplementedError:
# Fallback for devices like MPS where index_copy_ might not be supported.
self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states
return self.keys, self.values
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
"""Return the length and offset of the cache, used to generate the attention mask"""
kv_offset = 0
kv_length = self.max_cache_len
return kv_length, kv_offset
def get_seq_length(self) -> int:
"""Returns the sequence length of the cached states."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
return (self.keys[0, 0].any(dim=-1)).sum() if self.is_initialized else 0
def get_max_cache_shape(self) -> int:
"""Return the maximum cache shape of the cache"""
return self.max_cache_len

You are right, it's not working :c. My reproduction script, which uses GPT-2 instead of Whisper, in case you need it (it does run successfully):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import StaticCache
device = "cuda"
model_id = "openai-community/gpt2"
model = AutoModelForCausalLM.from_pretrained(
model_id,
dtype=torch.float16,
attn_implementation="sdpa",
ignore_mismatched_sizes=True
).to(device)
model.eval()
def decode_step(model, input_ids, past_key_values, cache_position):
out = model(
input_ids=input_ids,
past_key_values=past_key_values,
cache_position=cache_position,
use_cache=True,
)
logits = out.logits
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
return next_token, out.past_key_values
compiled_decode = torch.compile(decode_step, mode="reduce-overhead", fullgraph=True)
max_batch_size = 8
max_seq_len = 64
dtype = torch.float16
past_key_values = StaticCache(
config=model.config,
max_batch_size=max_batch_size,
max_cache_len=max_seq_len,
device=device,
dtype=dtype
)
batch_sizes = [8, 4, 2, 1]
try:
for bs in batch_sizes:
print(f"Batch Size: {bs}")
past_key_values.reset()
seq_len = 3
input_ids = torch.randint(0, 1000, (bs, seq_len), device=device)
cache_position = torch.arange(seq_len, device=device)
with torch.no_grad():
out = model(input_ids, past_key_values=past_key_values, cache_position=cache_position)
cur_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
cache_position = torch.tensor([seq_len], device=device)
cur_token, _ = compiled_decode(model, cur_token, past_key_values, cache_position)
print("Success")
except Exception as e:
print(f"Failed on {bs} with error {e}") |
|
@i3hz yeah because the issue is that, at some point it returns |
|
I've implemented the self.keys_ and self.values_ functionality . So the testing script from earlier does work . But torch compile still fails with a segmentation fault which I'm working on . (misclicked and accidentally closed the pr mb) |
class StaticLayer(CacheLayerMixin):
"""
A static cache layer that stores the key and value states as static tensors of shape `[batch_size, num_heads, max_cache_len, head_dim]`.
It lazily allocates its full backing tensors, and then mutates them in-place. Built for `torch.compile` support.
Args:
max_cache_len (`int`):
Maximum number of tokens that can be stored, used for tensor preallocation.
"""
is_compileable = True
is_sliding = False
def __init__(self, max_cache_len: int, max_batch_size: int | None = None):
super().__init__()
self.max_cache_len = max_cache_len
self.max_batch_size = max_batch_size
def lazy_initialization(self, key_states: torch.Tensor):
"""
Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).
If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
internally don't compile the prefill, this is guaranteed to have been called already when compiling.
If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
not be compiled anyway for performances!
"""
if self.max_batch_size is None:
self.max_batch_size = key_states.shape[0]
_, self.num_heads, _, self.head_dim = key_states.shape
self.dtype, self.device = key_states.dtype, key_states.device
self.keys_ = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.values_ = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.keys = self.keys_
self.values = self.values_
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
# breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
# As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
# prefill explicitly, but this should be avoided!)
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(self.keys_)
torch._dynamo.mark_static_address(self.values_)
self.is_initialized = True
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Update the key and value caches in-place, and return the necessary keys and value states.
Args:
key_states (`torch.Tensor`): The new key states to cache.
value_states (`torch.Tensor`): The new value states to cache.
cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache.
Returns:
tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
"""
# Lazy initialization
if not self.is_initialized:
self.lazy_initialization(key_states)
# Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as cross-attention,
# in which case we should copy the whole Layer (key_states.shape[-2] == self.max_cache_len)
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
cache_position = (
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
)
batch_size = key_states.shape[0]
# 3. Dynamic Slicing: Update the view to match current batch
self.keys = self.keys_[:batch_size]
self.values = self.values_[:batch_size]
try:
self.keys.index_copy_(2, cache_position, key_states)
self.values.index_copy_(2, cache_position, value_states)
except NotImplementedError:
self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states
return self.keys, self.values
def reset(self):
if self.is_initialized:
self.keys_.zero_()
self.values_.zero_()
self.keys = self.keys_
self.values = self.values_
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
"""Return the length and offset of the cache, used to generate the attention mask"""
kv_offset = 0
kv_length = self.max_cache_len
return kv_length, kv_offset
def get_seq_length(self) -> int:
"""Returns the sequence length of the cached states."""
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
return (self.keys[0, 0].any(dim=-1)).sum() if self.is_initialized else 0
def get_max_cache_shape(self) -> int:
"""Return the maximum cache shape of the cache"""
return self.max_cache_len

This is my current implementation.

Ah right, encoder-decoder ones are a bit different. Naming the kv tensors with max batch size differently sounds good to me; a slightly more informative name would probably be better, it's easy to lose track when reading the code. @i3hz can you push the code you have with all the updates? Also, in which cases are you getting a seg fault: in the test files or in the bench script? It is important not to compile the prefill stage in a custom generation loop, or, if we have to compile in advance, then the cache has to be early-initialized. The lazy init function is known to fail when compiled.

The code is a bit messy but I'll change it later, sorry.
Yeah probably
It failed on my end. Can you please look into the implementation and let me know if I messed something up?

Yeah, I forgot about this, sorry. I split the test into eager prefill + compiled decode and it now passes successfully. But it is failing the CI tests, which is something I'm working on.

I would really appreciate any pointers on why the CI tests are failing @zucchini-nlp

The failures on qwen3 aren't related; that probably happened due to too many requests, which caused a failure to download image data. I will take a look at this PR around this week, I have been quite busy with other tasks lately. In the meantime, can you make sure that the PR is ready and has the final bench script and performance results in the description? Also, I think it's nice to allow users to initialize a cache with max batch size from

Btw, can you also do a benchmark with speed? Based on the logs, it seems this creates an issue with cudagraphs; I am not sure though if this is critical:
@i3hz There's a performance regression with this change using torch.compile, you can reproduce it:

import torch
from transformers import WhisperForConditionalGeneration, AutoProcessor
from transformers.cache_utils import StaticCache, EncoderDecoderCache
import math
from tqdm import tqdm
import numpy as np
device = 'cuda:0'
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3-turbo"
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="sdpa", device_map=device)
model.generation_config.cache_implementation = "static"
@torch.no_grad()
def run_encoder_raw(model, labels, encoder_outputs, past_key_values, prefill: bool):
seq_length = labels.shape[-1]
if(prefill):
cache_position = torch.arange(seq_length, device=device)
else:
cache_position = torch.tensor([seq_length], device=device)
out_decoder = model.model.decoder(
labels,
encoder_hidden_states=encoder_outputs,
past_key_values = past_key_values,
cache_position=cache_position,
use_cache = True,
return_dict=True,
)
cur_token = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(axis=-1)
past_key_values = out_decoder.past_key_values
return cur_token, past_key_values
max_batch_size = 32
max_cache_len = 256
enc_len = 1500
decoder = model.model.decoder
def create_cache(max_batch_size):
# Cache
self_cache = StaticCache(
config=decoder.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=torch_dtype,
)
cross_cache = StaticCache(
config=decoder.config,
max_batch_size=max_batch_size,
max_cache_len=enc_len,
device=device,
dtype=torch_dtype,
)
return EncoderDecoderCache(self_cache, cross_cache)
max_batch_size = 64
past_key_values = create_cache(max_batch_size)
run_encoder = run_encoder_raw
for _ in range(2):
for current_bs in [max_batch_size]: #[8, 4, 2, 1]:
past_key_values.reset()
assert current_bs <= max_batch_size, "batch_size should be <= max_batch_size"
seq_length = 3
labels = torch.tensor([[50258, 50259, 50360]] * current_bs, device=device, dtype=torch.int64)
encoder_outputs = torch.randn([current_bs, enc_len, 1280], device=device, dtype=torch_dtype)
cur_token, past_key_values_out = run_encoder(model, labels, encoder_outputs, past_key_values, prefill=True)
# warm-uup
for _ in range(3):
cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)
torch.cuda.synchronize()
for _ in tqdm(range(500)):
cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)
torch.cuda.synchronize()
print(f"{current_bs} pass!")
run_encoder = torch.compile(run_encoder_raw, mode='reduce-overhead', fullgraph=True)
print('----- compiled run ---- ')

Yeah, you're right, it's most likely from the slicing operations. Is there any way to optimize it?
@i3hz it's weird, because that slicing is not doing any copy, I also tried
My assumption is that the issue comes from
It seems to me there's no clean way of doing this properly. One thing that could work is initializing the full cache outside and passing it to
One thing we can do: if the current batch size is equal to the max batch size, we just return the whole thing without slicing. But then again it's not really a fix for the performance reduction, though it does give us similar performance for your repro script @mobicham. The performance reduction is just related to the slicing. I tried a basic script where I compare two compiled functions, one just returning the tensor and the other slicing it, and over 10k iterations the slicing one is more than 2x slower. So what can we do?
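Roughly the kind of comparison described above (an illustrative sketch, not the exact script; shapes, device handling, and iteration count are arbitrary):
import time
import torch

@torch.compile
def return_full(x):
    # just hand back the preallocated tensor
    return x

@torch.compile
def return_sliced(x, bs: int):
    # same tensor, sliced down to the current batch size
    return x[:bs]

x = torch.zeros(8, 6, 256, 64, device="cuda" if torch.cuda.is_available() else "cpu")

def bench(fn, *args, iters=10_000):
    fn(*args)  # warm-up / trigger compilation
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    return time.perf_counter() - start  # rough wall-clock time; no kernels are launched here

print("return_full  :", bench(return_full, x))
print("return_sliced:", bench(return_sliced, x, 4))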
@i3hz yeah, I am not sure, I asked on the GPU MODE Discord, let's see what the Torch folks say about this slicing op
@i3hz I made it work, no perf regression and VRAM usage in check! Feel free to add the other checks and the comment. transformers/src/transformers/cache_utils.py, line 1077 (at 5b4d72c).
Right now I hard-coded it to 128 for testing, but we need to get it from kwargs:

class StaticLayer(CacheLayerMixin):
is_compileable = True
is_sliding = False
def __init__(self, max_cache_len: int):
super().__init__()
self.max_cache_len = max_cache_len
self.max_batch_size = 128
def init_full_cache(self, key_states: torch.Tensor):
""" Init the full batch with max_batch_size """
_, self.num_heads, _, self.head_dim = key_states.shape
self.dtype, self.device = key_states.dtype, key_states.device
self.keys_full = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.values_full = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
def lazy_initialization(self, key_states: torch.Tensor):
""" Mark each chunk with torch static address """
self.init_full_cache(key_states)
self.keys_map = [None] * (self.max_batch_size + 1)
self.values_map = [None] * (self.max_batch_size + 1)
for batch_size in range(1, self.max_batch_size + 1):
self.keys_map[batch_size] = self.keys_full[:batch_size]
self.values_map[batch_size] = self.values_full[:batch_size]
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(self.keys_map[batch_size])
torch._dynamo.mark_static_address(self.values_map[batch_size])
self.is_initialized = True
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.is_initialized:
self.lazy_initialization(key_states)
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
cache_position = (
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
)
batch_size = key_states.shape[0]
self.keys = self.keys_map[batch_size]
self.values = self.values_map[batch_size]
self._current_batch_size = batch_size
try:
self.keys.index_copy_(2, cache_position, key_states)
self.values.index_copy_(2, cache_position, value_states)
except NotImplementedError:
self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states
return self.keys, self.values
def reset(self):
if self.is_initialized:
self.keys_full.zero_()
self.values_full.zero_()
self._current_batch_size = None
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
kv_offset = 0
kv_length = self.max_cache_len
return kv_length, kv_offset
def get_seq_length(self) -> int:
return (self.keys_full[0, 0].any(dim=-1)).sum() if self.is_initialized else 0
def get_max_cache_shape(self) -> int:
return self.max_cache_len
I've passed in this branch :)

Yeah, LGTM. I think it will add some latency during the first forward pass for massive batch sizes, so I'd just like to confirm with @zucchini-nlp before pushing the changes.

@i3hz I didn't see any particular increase in latency at cache creation; the extra stuff is the loop, and that loop is not doing any copy. But I do agree we need confirmation from @zucchini-nlp.

Hey, sorry again for the delay. Interesting that the slicing op is causing issues, I was expecting that the sliced keys would share the same static address. The above workaround doesn't look bad to me, since we're still creating a single cache object. I wonder if marking it static across different batch sizes affects anything 🤔

Why is this test failing? Is it because we're calling it keys_full instead of keys? Anyway, that's the only test that's failing.
@i3hz sorry for the late answer. I haven't run the test, but I guess

Yep, that's the one. All good, take your time (I'm also on break lol).
@zucchini-nlp should I close this PR?

Apologies for the delay, I just got back from a break. The issue in the test is that, by design, the cache needs to have max_batch_size as dim 0. What do you think?
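Presumably the adjustment under discussion is something along these lines (a hypothetical test fragment; the keys/keys_full attribute names follow the snippet above and may differ in the final code):
# Hypothetical: the full backing tensor keeps max_batch_size as dim 0,
# while the view returned by update() may be smaller, hence <=.
assert cache.layers[0].keys_full.shape[0] == max_batch_size
assert cache.layers[0].keys.shape[0] <= max_batch_size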
That does sound reasonable, but I'd like to confirm with @zucchini-nlp before making that change. Does it pass the test after the change?

Yeah, ofc, feel free to adjust tests as needed.

@mobicham The test doesn't pass after adding the <=, or am I missing something?
Another thing I found is on main but on this branch: why isn't the cache being registered?

If we move the slicing into lazy_initialization:

def lazy_initialization(self, key_states: torch.Tensor):
"""
Lazy initialization of the keys and values tensors. This allows to get all properties (dtype, device,
num_heads in case of TP etc...) at runtime directly, which is extremely practical as it avoids moving
devices, dtypes etc later on for each `update` (which could break the static dynamo addresses as well).
If this is unwanted, one can call `early_initialization(...)` on the Cache directly, which will call this
function ahead-of-time (this is required for `torch.export` for example). Note that for `compile`, as we
internally don't compile the prefill, this is guaranteed to have been called already when compiling.
If compiling the prefill as well, e.g. calling `model.compile(...)` before `generate` with a static cache,
it is still supported in general, but without guarantees depending on the compilation options (e.g. cuda graphs,
i.e. `mode="reduce-overhead"` is known to fail). But it will in general work correctly, and prefill should
not be compiled anyway for performances!
"""
self.init_full_cache(key_states)
self.keys_map = [None] * (self.max_batch_size + 1)
self.values_map = [None] * (self.max_batch_size + 1)
for batch_size in range(1, self.max_batch_size + 1):
self.keys_map[batch_size] = self.keys_full[:batch_size]
self.values_map[batch_size] = self.values_full[:batch_size]
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing compiled graph
# breaks when updating the cache. However, it is not supported when tracing the graph, so we skip it in this case.
# As prefill should never be compiled, this is not an issue and it will still be run (except when users compile
# prefill explicitly, but this should be avoided!)
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(self.keys_map[batch_size])
torch._dynamo.mark_static_address(self.values_map[batch_size])
self.is_initialized = True
batch_size = key_states.shape[0]
# Slice to current batch size
self.keys = self.keys_map[batch_size]
self.values = self.values_map[batch_size]

What do you think? I think this is due to
import torch
from transformers import WhisperForConditionalGeneration, AutoProcessor
from transformers.cache_utils import StaticCache, EncoderDecoderCache
import math
from tqdm import tqdm
import numpy as np
device = 'cuda:0'
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3-turbo"
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="sdpa", device_map=device)
model.generation_config.cache_implementation = "static"
@torch.no_grad()
def run_encoder_raw(model, labels, encoder_outputs, past_key_values, prefill: bool):
seq_length = labels.shape[-1]
if(prefill):
cache_position = torch.arange(seq_length, device=device)
else:
cache_position = torch.tensor([seq_length], device=device)
out_decoder = model.model.decoder(
labels,
encoder_hidden_states=encoder_outputs,
past_key_values = past_key_values,
cache_position=cache_position,
use_cache = True,
return_dict=True,
)
cur_token = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(axis=-1)
past_key_values = out_decoder.past_key_values
return cur_token, past_key_values
max_batch_size = 32
max_cache_len = 256
enc_len = 1500
decoder = model.model.decoder
def create_cache(max_batch_size):
# Cache
self_cache = StaticCache(
config=decoder.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=torch_dtype,
)
cross_cache = StaticCache(
config=decoder.config,
max_batch_size=max_batch_size,
max_cache_len=enc_len,
device=device,
dtype=torch_dtype,
)
return EncoderDecoderCache(self_cache, cross_cache)
max_batch_size = 64
past_key_values = create_cache(max_batch_size)
run_encoder = run_encoder_raw
for _ in range(2):
for current_bs in [max_batch_size]: #[8, 4, 2, 1]:
past_key_values.reset()
assert current_bs <= max_batch_size, "batch_size should be <= max_batch_size"
seq_length = 3
labels = torch.tensor([[50258, 50259, 50360]] * current_bs, device=device, dtype=torch.int64)
encoder_outputs = torch.randn([current_bs, enc_len, 1280], device=device, dtype=torch_dtype)
cur_token, past_key_values_out = run_encoder(model, labels, encoder_outputs, past_key_values, prefill=True)
# warm-uup
for _ in range(3):
cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)
torch.cuda.synchronize()
for _ in tqdm(range(500)):
cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)
torch.cuda.synchronize()
print(f"{current_bs} pass!")
run_encoder = torch.compile(run_encoder_raw, mode='reduce-overhead', fullgraph=True)
print('----- compiled run ---- ')

I don't see any big difference in performance after using this either.

Thanks @i3hz, that's good news, I just tested it and it works 👍. I also don't see a difference in perf, so it should be good to go. You should test with

import torch
from transformers import WhisperForConditionalGeneration, AutoProcessor
from transformers.cache_utils import StaticCache, EncoderDecoderCache
import math
from tqdm import tqdm
import numpy as np
device = 'cuda:0'
torch_dtype = torch.float16
model_id = "openai/whisper-large-v3-turbo"
model = WhisperForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch_dtype, attn_implementation="sdpa", device_map=device)
model.generation_config.cache_implementation = "static"
@torch.no_grad()
def run_encoder_raw(model, labels, encoder_outputs, past_key_values, prefill: bool):
seq_length = labels.shape[-1]
if(prefill):
cache_position = torch.arange(seq_length, device=device)
else:
cache_position = torch.tensor([seq_length], device=device)
out_decoder = model.model.decoder(
labels,
encoder_hidden_states=encoder_outputs,
past_key_values = past_key_values,
cache_position=cache_position,
use_cache = True,
return_dict=True,
)
cur_token = model.proj_out(out_decoder.last_hidden_state[:,-1:]).argmax(axis=-1)
past_key_values = out_decoder.past_key_values
return cur_token, past_key_values
max_batch_size = 32
max_cache_len = 256
enc_len = 1500
decoder = model.model.decoder
##################################################################################################
#Monkey patch
from typing import Optional, Any
from transformers import cache_utils
from transformers.cache_utils import is_torchdynamo_compiling, CacheLayerMixin
class StaticLayer(CacheLayerMixin):
is_compileable = True
is_sliding = False
def __init__(self, max_cache_len: int):
super().__init__()
self.max_cache_len = max_cache_len
self.max_batch_size = 128
def init_full_cache(self, key_states: torch.Tensor):
""" Init the full batch with max_batch_size """
_, self.num_heads, _, self.head_dim = key_states.shape
self.dtype, self.device = key_states.dtype, key_states.device
self.keys_full = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
self.values_full = torch.zeros(
(self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim),
dtype=self.dtype,
device=self.device,
)
def lazy_initialization(self, key_states: torch.Tensor):
self.init_full_cache(key_states)
self.keys_map = [None] * (self.max_batch_size + 1)
self.values_map = [None] * (self.max_batch_size + 1)
for batch_size in range(1, self.max_batch_size + 1):
self.keys_map[batch_size] = self.keys_full[:batch_size]
self.values_map[batch_size] = self.values_full[:batch_size]
if not is_torchdynamo_compiling():
torch._dynamo.mark_static_address(self.keys_map[batch_size])
torch._dynamo.mark_static_address(self.values_map[batch_size])
self.is_initialized = True
batch_size = key_states.shape[0]
# Slice to current batch size
self.keys = self.keys_map[batch_size]
self.values = self.values_map[batch_size]
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if not self.is_initialized:
self.lazy_initialization(key_states)
cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
cache_position = (
cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
)
batch_size = key_states.shape[0]
self.keys = self.keys_map[batch_size]
self.values = self.values_map[batch_size]
self._current_batch_size = batch_size
try:
self.keys.index_copy_(2, cache_position, key_states)
self.values.index_copy_(2, cache_position, value_states)
except NotImplementedError:
self.keys[:, :, cache_position] = key_states
self.values[:, :, cache_position] = value_states
return self.keys, self.values
def reset(self):
if self.is_initialized:
self.keys_full.zero_()
self.values_full.zero_()
self._current_batch_size = None
def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
kv_offset = 0
kv_length = self.max_cache_len
return kv_length, kv_offset
def get_seq_length(self) -> int:
return (self.keys_full[0, 0].any(dim=-1)).sum() if self.is_initialized else 0
def get_max_cache_shape(self) -> int:
return self.max_cache_len
cache_utils.StaticLayer = StaticLayer
##################################################################################################
def create_cache(max_batch_size):
# Cache
self_cache = StaticCache(
config=decoder.config,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=torch_dtype,
)
cross_cache = StaticCache(
config=decoder.config,
max_batch_size=max_batch_size,
max_cache_len=enc_len,
device=device,
dtype=torch_dtype,
)
return EncoderDecoderCache(self_cache, cross_cache)
max_batch_size = 64
past_key_values = create_cache(max_batch_size)
run_encoder = run_encoder_raw
for _ in range(2):
for current_bs in [1]: #[8, 4, 2, 1]:
past_key_values.reset()
assert current_bs <= max_batch_size, "batch_size should be <= max_batch_size"
seq_length = 3
labels = torch.tensor([[50258, 50259, 50360]] * current_bs, device=device, dtype=torch.int64)
encoder_outputs = torch.randn([current_bs, enc_len, 1280], device=device, dtype=torch_dtype)
cur_token, past_key_values_out = run_encoder(model, labels, encoder_outputs, past_key_values, prefill=True)
# warm-uup
for _ in range(3):
cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)
torch.cuda.synchronize()
for _ in tqdm(range(500)):
cur_token, past_key_values_out = run_encoder(model, cur_token.clone(), encoder_outputs, past_key_values_out, prefill=False)
torch.cuda.synchronize()
print(f"{current_bs} pass!")
run_encoder = torch.compile(run_encoder_raw, mode='reduce-overhead', fullgraph=True)
print('----- compiled run ---- ')

Output:

Thanks a lot for your feedback @mobicham. What should I add as a test? Should it just be a similar version of the benchmarking script?
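One possible shape for such a test, mirroring the GPT-2 repro above (a sketch only; the model id, cache length, and test name are illustrative, not the PR's final test):
import torch
from transformers import AutoModelForCausalLM
from transformers.cache_utils import StaticCache

def test_static_cache_reused_across_smaller_batches():
    # Sketch: reuse one StaticCache across shrinking batch sizes with an eager
    # prefill and a compiled decode step, and check that shapes stay consistent.
    device = "cuda"
    model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", dtype=torch.float16).to(device).eval()

    def decode_step(model, input_ids, past_key_values, cache_position):
        out = model(input_ids=input_ids, past_key_values=past_key_values,
                    cache_position=cache_position, use_cache=True)
        return out.logits[:, -1, :].argmax(dim=-1, keepdim=True)

    compiled_decode = torch.compile(decode_step, mode="reduce-overhead", fullgraph=True)

    max_batch_size, seq_len = 8, 3
    cache = StaticCache(config=model.config, max_batch_size=max_batch_size,
                        max_cache_len=64, device=device, dtype=torch.float16)

    for bs in [8, 4, 2, 1]:
        cache.reset()
        input_ids = torch.randint(0, 1000, (bs, seq_len), device=device)
        with torch.no_grad():  # eager prefill
            out = model(input_ids, past_key_values=cache,
                        cache_position=torch.arange(seq_len, device=device))
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        next_token = compiled_decode(model, next_token, cache,
                                     torch.tensor([seq_len], device=device))  # compiled decode
        assert next_token.shape == (bs, 1)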
What does this PR do?
Fixes #42454
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@zucchini-nlp @Rocketknight1 @mobicham
Benchmarking script -