[Bug] Positional Embedding Index Out-of-Bounds in AuraFlow #12656

@ikuto32

Describe the bug

When training an AuraFlow (Pony V7) LoRA via SimpleTuner, the positional embedding selection function
pe_selection_index_based_on_dim() can generate negative or out-of-range indices into the positional embedding table.

This causes:

CUDA error: device-side assert triggered
vectorized_gather_kernel: Assertion `ind >=0 && ind < ind_dim_size` failed.

Root cause: the code computes a centered crop of the PE grid, but when the input latent grid
(h // patch_size, w // patch_size) is larger than the pretrained positional embedding grid along either axis,
the crop no longer fits: its start can become negative and its end can exceed the largest grid index, so the flattened indices are invalid.
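
The per-axis arithmetic behind this can be sketched in plain Python. The helper names crop_bounds and fits are hypothetical and are not part of diffusers; the sketch only mirrors the start computation from pe_selection_index_based_on_dim() for one axis, assuming the 96 x 96 grid used in the reproduction.

```python
def crop_bounds(size, patch_size, grid_size):
    """Return (start, stop) of the centered crop along one axis of the PE grid.

    Hypothetical helper: same start formula as the reported function,
    applied to a single axis.
    """
    n_patches = size // patch_size
    start = grid_size // 2 - n_patches // 2
    return start, start + n_patches


def fits(size, patch_size, grid_size):
    """True when the crop lies entirely inside [0, grid_size)."""
    start, stop = crop_bounds(size, patch_size, grid_size)
    return 0 <= start and stop <= grid_size


grid = 96  # side length of the PE grid (assumed, as in the reproduction)
print(crop_bounds(196, 2, grid), fits(196, 2, grid))  # (-1, 97) False
print(crop_bounds(128, 2, grid), fits(128, 2, grid))  # (16, 80) True
```

With patch_size = 2 and a grid side of 96, the crop fits only while size // patch_size <= 96, which is the condition the 196 x 196 input violates.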

Reproduction

import torch

# Simulated model parameters
patch_size = 2
pos_embed_max_size = 96 * 96  # 9216 positions
pos_embed = torch.zeros(1, pos_embed_max_size, 3072)

def pe_selection_index_based_on_dim(h, w):
    h_p, w_p = h // patch_size, w // patch_size
    h_max = w_max = int(pos_embed_max_size ** 0.5)

    starth = h_max // 2 - h_p // 2
    startw = w_max // 2 - w_p // 2

    rows = torch.arange(starth, starth + h_p)
    cols = torch.arange(startw, startw + w_p)
    row_idx, col_idx = torch.meshgrid(rows, cols, indexing="ij")

    return (row_idx * w_max + col_idx).flatten()

# Example that produces invalid indices (real case)
h, w = 196, 196  # latent spatial dims
idx = pe_selection_index_based_on_dim(h, w)

print(idx.min().item(), idx.max().item())   # → negative & >9215
assert (idx >= 0).all() and (idx < pos_embed_max_size).all()

Observed Output

-97 9312
AssertionError

This reproduces the exact index state that triggers the CUDA vectorized_gather_kernel failure during training.
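
One possible way to surface the problem earlier is to validate the crop before any indices are built, so the failure is a readable Python exception rather than a device-side assert. The following is only a sketch under that assumption: checked_crop_starts is a hypothetical helper, not the diffusers implementation and not a proposed patch.

```python
import math


def checked_crop_starts(h, w, patch_size, pos_embed_max_size):
    """Hypothetical guard: return (starth, startw) or raise if the crop cannot fit."""
    h_p, w_p = h // patch_size, w // patch_size
    side = math.isqrt(pos_embed_max_size)  # 96 for 9216 positions
    # Check each axis separately instead of checking the flattened index.
    if h_p > side or w_p > side:
        raise ValueError(
            f"latent grid {h_p}x{w_p} exceeds positional embedding grid {side}x{side}"
        )
    return side // 2 - h_p // 2, side // 2 - w_p // 2


print(checked_crop_starts(128, 128, 2, 96 * 96))  # (16, 16)

try:
    checked_crop_starts(196, 196, 2, 96 * 96)
except ValueError as err:
    print(err)
```

The check is done per axis because a flattened-index check can miss some cases. For example, with h = 64 and w = 196 the flattened indices from the reproduction function range from 3071 to 6144, all inside [0, 9216), yet the column range is -1 to 96, so positions would wrap into neighbouring rows without triggering the assert.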

System Info

  • 🤗 Diffusers version: 0.35.2
  • Platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39
  • Running on Google Colab?: No
  • Python version: 3.12.3
  • PyTorch version (GPU?): 2.9.0+cu130 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.36.0
  • Transformers version: 4.57.1
  • Accelerate version: 1.11.0
  • PEFT version: 0.17.1
  • Bitsandbytes version: 0.48.2
  • Safetensors version: 0.6.2
  • xFormers version: not installed
  • Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no
