Open
Labels
bug: Something isn't working
Description
Describe the bug
When training an AuraFlow (Pony V7) LoRA via SimpleTuner, the positional embedding selection function
pe_selection_index_based_on_dim() can generate negative or out-of-range indices into the positional embedding table.
This causes:
CUDA error: device-side assert triggered
vectorized_gather_kernel: Assertion `ind >=0 && ind < ind_dim_size` failed.
Root cause: the code computes a centered crop of the PE grid, but when the input latent grid
(h // patch_size, w // patch_size) is larger than the pretrained positional embedding grid,
the computed indices become invalid (negative start, or exceeding max grid index).
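The condition described above can be checked before any indexing happens. The following is only a sketch of such a guard, not the Diffusers implementation: `check_pe_grid` is a hypothetical helper name, and it assumes a square PE grid whose side is the integer square root of `pos_embed_max_size`, as in the reproduction in this report.

```python
import math


def check_pe_grid(h, w, patch_size, pos_embed_max_size):
    """Hypothetical guard: fail early if the latent grid exceeds the PE grid."""
    # Assumption: the PE table is a square grid (96 x 96 for 9216 positions).
    side = math.isqrt(pos_embed_max_size)
    h_p, w_p = h // patch_size, w // patch_size
    if h_p > side or w_p > side:
        raise ValueError(
            f"latent grid {h_p}x{w_p} exceeds positional embedding grid "
            f"{side}x{side}; a centered crop would produce invalid indices"
        )
    return h_p, w_p


print(check_pe_grid(192, 192, 2, 96 * 96))  # fits: the grid is exactly 96 x 96
try:
    check_pe_grid(196, 196, 2, 96 * 96)  # 98 x 98 does not fit
except ValueError as err:
    print(err)
```

Checking the per-axis grid size, rather than the flattened index range, gives a readable Python error on the host instead of a device-side assert.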
Reproduction
import torch
# Simulated model parameters
patch_size = 2
pos_embed_max_size = 96 * 96 # 9216 positions
pos_embed = torch.zeros(1, pos_embed_max_size, 3072)
def pe_selection_index_based_on_dim(h, w):
    h_p, w_p = h // patch_size, w // patch_size
    h_max = w_max = int(pos_embed_max_size ** 0.5)
    starth = h_max // 2 - h_p // 2
    startw = w_max // 2 - w_p // 2
    rows = torch.arange(starth, starth + h_p)
    cols = torch.arange(startw, startw + w_p)
    row_idx, col_idx = torch.meshgrid(rows, cols, indexing="ij")
    return (row_idx * w_max + col_idx).flatten()
# Example that produces invalid indices (real case)
h, w = 196, 196 # latent spatial dims
idx = pe_selection_index_based_on_dim(h, w)
print(idx.min().item(), idx.max().item()) # → negative & >9215
assert (idx >= 0).all() and (idx < pos_embed_max_size).all()
Observed Output
-97 9312
AssertionError
This reproduces the exact index state that triggers the CUDA vectorized_gather_kernel failure during training.
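A related case may be worth noting: with a non-square latent, the flattened indices can stay inside the table while the column range is still invalid, so no assert fires but the wrong embedding rows would be selected. The sketch below redoes the same arithmetic as the reproduction without torch; `index_bounds` is a hypothetical helper written for this illustration, and the 100 x 196 latent is an assumed example rather than a shape observed during training.

```python
import math

patch_size = 2
pos_embed_max_size = 96 * 96  # same simulated parameters as the reproduction


def index_bounds(h, w):
    """Return (min flat index, max flat index, whether the 2D crop is valid)."""
    side = math.isqrt(pos_embed_max_size)
    h_p, w_p = h // patch_size, w // patch_size
    start_h = side // 2 - h_p // 2
    start_w = side // 2 - w_p // 2
    # row * side + col grows with both row and col, so the extremes
    # sit at the first and last (row, col) pair of the crop.
    lo = start_h * side + start_w
    hi = (start_h + h_p - 1) * side + (start_w + w_p - 1)
    grid_ok = (
        start_h >= 0 and start_h + h_p <= side
        and start_w >= 0 and start_w + w_p <= side
    )
    return lo, hi, grid_ok


print(index_bounds(196, 196))  # (-97, 9312, False): matches the observed output
print(index_bounds(100, 196))  # (2207, 7008, False): in range, yet columns overflow
```

In the second call every flat index lies in [0, 9216), so a range check on the flattened indices would pass even though columns -1 and 96 fall outside the 96-wide grid and alias into neighbouring rows.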
Logs
System Info
- 🤗 Diffusers version: 0.35.2
- Platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39
- Running on Google Colab?: No
- Python version: 3.12.3
- PyTorch version (GPU?): 2.9.0+cu130 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.36.0
- Transformers version: 4.57.1
- Accelerate version: 1.11.0
- PEFT version: 0.17.1
- Bitsandbytes version: 0.48.2
- Safetensors version: 0.6.2
- xFormers version: not installed
- Accelerator: NVIDIA GeForce RTX 4090, 24564 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help?
No response