Description
Hi Infinity Team,
First of all, thank you for your incredible work on this project! I'm currently trying to run inference with the provided pretrained weights (`infinity_2b_reg.pth`) but have encountered a `RuntimeError` due to a parameter shape mismatch when loading the model's state dictionary. It seems the issue stems from how the model handles the `customized_flash_attn` flag and its effect on the internal architecture, specifically the shape of the `scale_mul_1H11` parameter in the `SelfAttention` blocks.
Environment
- PyTorch Version: 2.5.1+cu118
- CUDA Version: 11.8
- GPU: NVIDIA GeForce RTX 3090
- `flash-attn` Version: Installed via `pip install flash-attn --no-build-isolation`. A diagnostic script (see the sketch below) confirms it is installed correctly and that the CUDA kernel is executable.
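For reference, the diagnostic I used is roughly the sketch below (my own ad-hoc check, not a script shipped with this repo; the tensor shapes and dtypes are just what the stock `flash_attn_func` expects):

```python
# Quick sanity check that the stock flash-attn CUDA kernel runs on this GPU.
import torch
from flash_attn import flash_attn_func

# flash_attn_func expects (batch, seqlen, num_heads, head_dim) tensors in fp16/bf16 on CUDA.
q = torch.randn(1, 128, 16, 64, dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attn_func(q, k, v, causal=True)
print('flash-attn OK, output shape:', out.shape)  # torch.Size([1, 128, 16, 64])
```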
The Problem
The root cause appears to be in the `Infinity` model's `__init__` method, which checks for a custom-patched version of `flash-attn`:
```python
# In infinity/models/infinity.py
customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames)
self.customized_flash_attn = customized_flash_attn and customized_kernel_installed
```
When using the standard `flash-attn` package from PyPI, `flash_attn_func` does not have any arguments containing "Infinity", so `customized_kernel_installed` becomes `False`.
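You can verify what the check is looking at by running it against the stock package in isolation (the exact `co_varnames` tuple may vary across `flash-attn` versions, but none of the names are Infinity-specific):

```python
# Inspect the stock flash_attn_func signature to see why the custom-kernel check fails.
from flash_attn import flash_attn_func

print(flash_attn_func.__code__.co_varnames)
# None of these argument names contain 'Infinity' on the PyPI build, so:
print(any('Infinity' in name for name in flash_attn_func.__code__.co_varnames))  # False
```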
This `self.customized_flash_attn` flag is then passed to the `SelfAttention` block in `infinity/models/basic.py`, which determines the shape of a key parameter:
```python
# In infinity/models/basic.py
class SelfAttention(nn.Module):
    def __init__(self, ..., customized_flash_attn=True, ...):
        # ...
        self.using_flash = customized_flash_attn
        if self.cos_attn:
            # The shape of this parameter depends on `self.using_flash`
            size = (1, 1, self.num_heads, 1) if self.using_flash else (1, self.num_heads, 1, 1)
            self.scale_mul_1H11 = nn.Parameter(...)
```
When I run the inference script with `--flash=1` (setting `customized_flash_attn=True` at initialization), the following happens:
- The custom kernel check fails, so `self.customized_flash_attn` is reset to `False` inside the `Infinity` class.
- `SelfAttention` blocks are therefore instantiated with `customized_flash_attn=False`, so `self.using_flash` is `False`.
- This sets the shape of `scale_mul_1H11` to `(1, num_heads, 1, 1)`, which for the 2B model is `(1, 16, 1, 1)` (see the standalone sketch below).
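For concreteness, here is that branch in isolation (a standalone snippet; `num_heads=16` is taken from the error message for the 2B checkpoint):

```python
# Standalone reproduction of the shape branch in SelfAttention.__init__ (num_heads=16 for the 2B model).
num_heads = 16
for using_flash in (False, True):
    size = (1, 1, num_heads, 1) if using_flash else (1, num_heads, 1, 1)
    print(f"using_flash={using_flash} -> scale_mul_1H11 shape {size}")
# using_flash=False -> (1, 16, 1, 1)   matches the checkpoint
# using_flash=True  -> (1, 1, 16, 1)   triggers the size mismatch shown below
```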
However, when I try to bypass the check (by forcing `customized_kernel_installed = True`), `self.using_flash` becomes `True`, and the shape of `scale_mul_1H11` becomes `(1, 1, num_heads, 1)`, i.e. `(1, 1, 16, 1)`. This leads to the following `RuntimeError`:
```
RuntimeError: Error(s) in loading state_dict for Infinity:
	size mismatch for block_chunks.0.module.0.sa.scale_mul_1H11: copying a param with shape torch.Size([1, 16, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 1, 16, 1]).
	... (and so on for all blocks)
```
This error message implies that the pretrained checkpoint was saved from a model where `using_flash` was `False`, corresponding to a `scale_mul_1H11` shape of `(1, 16, 1, 1)`.
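One way to confirm this directly is to inspect the checkpoint's state dict, for example with a quick sketch like the one below (the parameter name is taken from the error message; the unwrapping logic is a guess in case the weights are nested under a wrapper key):

```python
# Inspect the shape that was actually saved for scale_mul_1H11 in the pretrained checkpoint.
import torch

ckpt = torch.load('infinity_2b_reg.pth', map_location='cpu')
# Some checkpoints nest the weights under 'state_dict' or 'model'; unwrap if needed.
state_dict = ckpt.get('state_dict', ckpt.get('model', ckpt)) if isinstance(ckpt, dict) else ckpt
for name, tensor in state_dict.items():
    if name.endswith('scale_mul_1H11'):
        print(name, tuple(tensor.shape))  # e.g. block_chunks.0.module.0.sa.scale_mul_1H11 (1, 16, 1, 1)
        break
```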
Proposed Solution / Question
The immediate workaround is to run the inference script without the `--flash=1` argument. This keeps `customized_flash_attn` as `False` from the beginning, aligning the model architecture with the checkpoint weights. The model then falls back to `torch.nn.functional.scaled_dot_product_attention`, which can still use a fused FlashAttention kernel under the hood (PyTorch's built-in SDPA backend, independent of the pip-installed `flash-attn` package).
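For what it's worth, a quick way to confirm that PyTorch's built-in FlashAttention SDPA backend is usable on this setup (again my own check, not part of the repo's scripts) is:

```python
# Confirm the FLASH_ATTENTION backend of scaled_dot_product_attention works on this GPU (PyTorch 2.5.1).
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# SDPA expects (batch, num_heads, seqlen, head_dim); the flash backend needs fp16/bf16 on CUDA.
q = torch.randn(1, 16, 128, 64, dtype=torch.float16, device='cuda')
k = torch.randn_like(q)
v = torch.randn_like(q)
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):  # restricts SDPA to the flash backend; errors if unavailable
    out = F.scaled_dot_product_attention(q, k, v)
print('SDPA flash backend OK:', out.shape)
```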
However, this raises a few questions for the benefit of the community:
- Could you clarify the intended setup for using FlashAttention with this project?
- Is there a specific patch or a forked repository of `flash-attn` that we should be using to enable the "customized" version?
- If so, could you please add instructions for compiling this custom version to the `README.md`?
This would greatly help users to reproduce the intended high-performance setup.
Thank you for your time and support!