RuntimeError: size mismatch when loading pretrained weights with standard flash-attn #117

@XyeaOvO

Description

Hi Infinity Team,

First of all, thank you for your incredible work on this project! I'm currently trying to run inference with the provided pretrained weights (infinity_2b_reg.pth) but have encountered a RuntimeError due to a parameter shape mismatch when loading the model's state dictionary.

It seems the issue stems from how the model handles the customized_flash_attn flag and its effect on the internal architecture, specifically the shape of the scale_mul_1H11 parameter in SelfAttention blocks.

Environment

  • PyTorch Version: 2.5.1+cu118
  • CUDA Version: 11.8
  • GPU: NVIDIA GeForce RTX 3090
  • flash-attn Version: Installed via pip install flash-attn --no-build-isolation. A diagnostic script confirms it is installed correctly and that the CUDA kernel is executable (a minimal version of that check is sketched below).
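
For reference, the diagnostic amounts to roughly the following (a minimal sketch with illustrative tensor shapes, not the exact script):

# Minimal sanity check that the stock flash-attn CUDA kernel executes.
import torch
from flash_attn import flash_attn_func

# flash_attn_func expects (batch, seqlen, num_heads, head_dim) fp16/bf16 tensors on the GPU.
q = torch.randn(1, 128, 16, 64, device='cuda', dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_func(q, k, v)
print(out.shape)  # torch.Size([1, 128, 16, 64]) if the kernel ran successfully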

The Problem

The root cause appears to be in the Infinity model's __init__ method, which checks for a custom-patched version of flash-attn:

# In infinity/models/infinity.py
customized_kernel_installed = any('Infinity' in arg_name for arg_name in flash_attn_func.__code__.co_varnames)
self.customized_flash_attn = customized_flash_attn and customized_kernel_installed

When using the standard flash-attn package from PyPI, flash_attn_func does not have any arguments containing "Infinity", so customized_kernel_installed becomes False.
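
This is easy to reproduce from a Python shell with the stock PyPI package (a quick sketch of the same check):

# Reproduce the model's kernel check outside of Infinity (stock flash-attn from PyPI).
from flash_attn import flash_attn_func

print(flash_attn_func.__code__.co_varnames)
# None of these argument names contain "Infinity", so the check evaluates to False:
print(any('Infinity' in name for name in flash_attn_func.__code__.co_varnames))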

This self.customized_flash_attn flag is then passed to the SelfAttention block in infinity/models/basic.py, which determines the shape of a key parameter:

# In infinity/models/basic.py
class SelfAttention(nn.Module):
    def __init__(self, ..., customized_flash_attn=True, ...):
        # ...
        self.using_flash = customized_flash_attn
        if self.cos_attn:
            # The shape of this parameter depends on `self.using_flash`
            size = (1, 1, self.num_heads, 1) if self.using_flash else (1, self.num_heads, 1, 1)
            self.scale_mul_1H11 = nn.Parameter(...)

When I run the inference script with --flash=1 (setting customized_flash_attn=True at initialization), the following happens:

  1. The custom kernel check fails, so self.customized_flash_attn is reset to False inside the Infinity class.
  2. SelfAttention blocks are instantiated with customized_flash_attn=False, so self.using_flash is False.
  3. This sets the shape of scale_mul_1H11 to (1, num_heads, 1, 1), which for the 2B model is (1, 16, 1, 1).

However, when I try to bypass the check (by forcing customized_kernel_installed = True), self.using_flash becomes True, and the shape of scale_mul_1H11 becomes (1, 1, num_heads, 1), i.e. (1, 1, 16, 1) for the 2B model. This leads to the following RuntimeError:

RuntimeError: Error(s) in loading state_dict for Infinity:
        size mismatch for block_chunks.0.module.0.sa.scale_mul_1H11: copying a param with shape torch.Size([1, 16, 1, 1]) from checkpoint, the shape in current model is torch.Size([1, 1, 16, 1]).
        ... (and so on for all blocks)

This error message implies that the pretrained checkpoint was saved with a model where using_flash was False, corresponding to a scale_mul_1H11 shape of (1, 16, 1, 1).
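
This can be verified directly against the checkpoint file, independent of how the model is constructed (a sketch; if the .pth nests the weights under a key such as 'model' or 'ema', unwrap it first):

# Inspect the scale_mul_1H11 shapes stored in the checkpoint.
import torch

ckpt = torch.load('infinity_2b_reg.pth', map_location='cpu')
# If the checkpoint is nested, unwrap it here, e.g. ckpt = ckpt['model'].
for name, tensor in ckpt.items():
    if name.endswith('scale_mul_1H11'):
        print(name, tuple(tensor.shape))  # expected to print (1, 16, 1, 1), per the error above
        break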

Proposed Solution / Question

The immediate workaround is to run the inference script without the --flash=1 argument. This keeps customized_flash_attn set to False from the beginning, so the model architecture matches the checkpoint weights. The model then correctly falls back to torch.nn.functional.scaled_dot_product_attention, which ships its own fused FlashAttention kernels in PyTorch (independent of the pip flash-attn package), so performance should remain reasonable.
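
To confirm that this fallback path can still dispatch to a fused flash backend on the same GPU, one can pin SDPA to it via PyTorch's public API (a sketch, unrelated to the project's own code):

# Check that scaled_dot_product_attention can use the flash backend here.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 16, 128, 64, device='cuda', dtype=torch.float16)  # (B, H, L, D)
k, v = torch.randn_like(q), torch.randn_like(q)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 16, 128, 64])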

However, this raises a question for the benefit of the community:

  1. Could you clarify the intended setup for using FlashAttention with this project?
  2. Is there a specific patch or a forked repository for flash-attn that we should be using to enable the "customized" version?
  3. If so, could you please add instructions for compiling this custom version to the README.md?

This would greatly help users reproduce the intended high-performance setup.

Thank you for your time and support!
