Conversation

@urlesistiana commented Sep 30, 2025

This PR adds two features to speed up Lumina 2 training:

Checkpointing + torch.compile on sub-modules

Currently, when using gradient checkpointing, torch.compile will skip all frames (modules) inside the checkpointed model, so those sub-modules have to be compiled first. (I don't know torch.compile very well, but this seems to be expected behavior.)

I added an env var "SDSCRIPTS_SELECTIVE_TORCH_COMPILE"; setting it to 1 will compile those sub-modules.
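
For reference, a minimal sketch of what the env var does (the `layers` attribute and the function name are illustrative, not the actual sd-scripts code):

```python
import os
import torch

def selective_compile(model: torch.nn.Module) -> None:
    # Only active when the env var from this PR is set.
    if os.environ.get("SDSCRIPTS_SELECTIVE_TORCH_COMPILE") != "1":
        return
    # Compile each transformer block individually so the compiled code is
    # still used inside gradient-checkpointed regions instead of being skipped.
    # "layers" is an assumed attribute name; the real Lumina 2 module may differ.
    for i, block in enumerate(model.layers):
        model.layers[i] = torch.compile(block)
```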

In my setup (training a rank 16 LoRA with gradient checkpointing, batch size 6, and resolution 1024), this reduces training time from 9.2 s/it to 6.6 s/it (1.5 s/img to 1.1 s/img).

Compiling only those core modules also significantly reduces compile time, from ~220 s (with "--torch_compile") to ~5 s.

torch.compile + Memory Budget API

In PyTorch 2.4 there is a new feature called the "Memory Budget API", which not only does checkpointing automatically but also recomputes only cheap operations, so it is faster than the traditional checkpointing method.

ref: https://pytorch.org/blog/activation-checkpointing-techniques/

I added an env var "SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET" to set the budget.
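
A minimal sketch of how the budget is applied, assuming the `torch._functorch.config.activation_memory_budget` knob described in the blog post above (this is experimental, so the exact config path may change):

```python
import os
import torch._functorch.config as functorch_config

def apply_activation_memory_budget() -> None:
    # Budget from the env var added in this PR; expected range is 0.0-1.0.
    budget = os.environ.get("SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET")
    if budget is None:
        return
    # 1.0 behaves like plain torch.compile (keep all activations);
    # values closer to 0.0 recompute more, like full activation checkpointing.
    functorch_config.activation_memory_budget = float(budget)

# Call once before torch.compile(model); the partitioner then decides what to
# recompute so that peak activation memory stays within the budget.
```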

In my setup (training a rank 16 LoRA with batch size 2, resolution 1024, the budget set to "0.5", without "--gradient_checkpointing" but with "SDSCRIPTS_SELECTIVE_TORCH_COMPILE"), this reduces training time from 3.03 s/it to 1.75 s/it (1.5 s/img to 0.88 s/img), about 70% faster.

Note: Training Lumina 2 without gradient checkpointing will OOM (>24 GB VRAM) even with batch size 1.


Tested on the latest sd3 branch Lumina 2 training: Torch 2.9 nightly (9/29/2025), CUDA 13.0, Python 3.12, Nvidia 4090.

I'm not sure how to expose these settings properly, and I have only tested Lumina 2, so I set them via env vars for now to keep the changes minimal. But "torch.compile + Memory Budget API" seems very powerful and could apply to any model by just changing a global setting.

Opened as a draft for suggestions.

@urlesistiana (Author) commented

Added the "--activation_memory_budget" argument. In theory, any model that can use torch.compile can benefit from this setting.

@urlesistiana marked this pull request as ready for review September 30, 2025 14:08
don't compile funcs with complex ops

simplify FeedForward to avoid "cache line invalidated" error
@kohya-ss (Owner) commented Oct 1, 2025

Thank you, this looks very promising. Please give me some time to investigate gradient checkpointing, torch.compile, and the Memory Budget API.
Did you test it on a Linux environment? Please let me know if you know about compatibility on Windows.

@urlesistiana (Author) commented

Yes, I only tested on Linux. I can't test on Windows because I don't have such a setup, but I expect it should be fine there.

We already have the "--torch_compile" option, which compiles the whole model. For Lumina 2, if the existing --torch_compile works, then "SDSCRIPTS_SELECTIVE_TORCH_COMPILE" should also work, since it only compiles some sub-modules. No breaking changes.

For the global --activation_memory_budget option: if it works, great. If not, just don't use it and it won't affect anything. No breaking changes.

One thing that does concern me: this "Memory Budget API" doesn't have any official documentation. Only the blog post above mentions it, calling it an "experimental feature", and I can't find it in the release notes or feature list, so I don't know whether it is a "Beta", "Prototype", or "Stable" feature. If it is not stable, maybe an env setting would be better.
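
Since it is undocumented, one option (just an idea, not something this PR does) would be to probe for the knob before setting it:

```python
import torch._functorch.config as functorch_config

# Guard against the experimental knob being renamed or removed in a future
# PyTorch release; skip the optimization if it is not available.
if hasattr(functorch_config, "activation_memory_budget"):
    functorch_config.activation_memory_budget = 0.5
else:
    print("activation_memory_budget not available in this PyTorch build; skipping")
```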

@FurkanGozukara commented

How does --torch_compile work during training? Or is this just for inference?

@urlesistiana (Author) commented

@FurkanGozukara Same as inference: it captures and builds graphs, fuses them, and speeds things up.

@FurkanGozukara commented

@urlesistiana looks excellent

So what are the negatives? Will gradient checkpointing no longer be used? Or block swap?

@urlesistiana (Author) commented

No negatives. If it works, great: a free speed-up. If it doesn't, because the model has unsupported or unusual code or operations, just don't use it.

I'm not sure whether block swap works with torch.compile; I haven't tested it.

But it can replace traditional gradient checkpointing with a smarter method that uses more VRAM for more speed. That's a positive.
