feat: torch compile within checkpointing and activation memory budget for Lumina 2 #2217
base: sd3
Conversation
remove SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET env
added "--activation_memory_budget" argument. In theory, models that can use torch.compile can benefit from this setting.
don't compile funcs with complex ops; simplify FeedForward to avoid "cache line invalidated" error
Thank you, this looks very promising. Please give me some time to investigate gradient checkpointing, torch.compile, and the Memory Budget API.
Yes, I only tested on Linux. I can't test on Windows because I don't have such a setup, but I'd guess it should be fine there. We already have the "--torch_compile" option, which compiles the whole model. For Lumina 2, if the existing --torch_compile works, then "SDSCRIPTS_SELECTIVE_TORCH_COMPILE" should also work; it only compiles some sub modules. No breaking changes. For the global --activation_memory_budget option, if it works, great. If not, just don't use it and it won't affect anything. No breaking changes either. One thing that does concern me: this "Memory Budget API" has no official documentation. Only this blog post mentions that it is an "experimental feature", and I can't find it in any release note or feature list, so I don't know whether it is a "Beta", "Prototype", or "Stable" feature. If it is not stable, maybe an env setting would be better.
How does --torch_compile work during training? Or is this just for inference?
@FurkanGozukara Same as inference: capture and build graphs, fuse them, speed things up.
@urlesistiana Looks excellent. So what are the negatives? Will gradient checkpointing not be used anymore? Or block swap?
No negatives. If it works, great, free speed-up. If it doesn't, because the model has strange or unsupported code or operations, just don't use it. I'm not sure whether block swap works with torch.compile; I haven't tested that. But it can replace traditional gradient checkpointing with a smarter scheme that can utilize more VRAM for more speed. That's a positive.
This PR has two features to speed up Lumina 2 training:
checkpointing + torch.compile sub modules
Currently, when using gradient checkpointing, torch.compile skips all frames (modules) inside the checkpointed model, so those sub modules have to be compiled explicitly first. (I don't know torch.compile very well, but this seems to be expected behavior.)
I added an env var "SDSCRIPTS_SELECTIVE_TORCH_COMPILE"; setting it to 1 will compile those sub modules.
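To make the idea concrete, here is a minimal toy sketch (not the actual Lumina 2 code in this PR; `ToyBlock`/`ToyModel` are made up for illustration): the inner blocks are compiled individually, so they stay compiled even though the outer forward runs under gradient checkpointing, which would otherwise cause torch.compile to skip every frame inside the checkpointed region.

```python
import os
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy stand-ins for the model's transformer blocks (illustration only).
class ToyBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x):
        return x + self.ff(x)

class ToyModel(nn.Module):
    def __init__(self, depth: int = 4):
        super().__init__()
        self.blocks = nn.ModuleList(ToyBlock() for _ in range(depth))

    def forward(self, x):
        for block in self.blocks:
            # Gradient checkpointing: activations are recomputed during backward.
            x = checkpoint(block, x, use_reentrant=False)
        return x

model = ToyModel()
if os.environ.get("SDSCRIPTS_SELECTIVE_TORCH_COMPILE") == "1":
    # Compile the sub modules directly; compiling only the outer model would
    # be skipped inside the checkpointed frames.
    for i, block in enumerate(model.blocks):
        model.blocks[i] = torch.compile(block)

out = model(torch.randn(2, 16, 64))
out.sum().backward()
```

In the real script the env check would wrap the actual Lumina 2 blocks rather than a toy model, but the compile-the-inner-blocks pattern is the same.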
In my setup (training a rank 16 LoRA with gradient checkpointing, batch size 6, resolution 1024), this reduces training time from 9.2s/it to 6.6s/it (1.5s/img to 1.1s/img).
Compiling only those core modules also significantly reduces compile time, from ~220s (with "--torch_compile") to ~5s.
torch.compile + Memory Budget API
In PyTorch 2.4 there is a new feature called the "Memory Budget API", which not only does checkpointing automatically but also recomputes only cheap operations, so it is faster than the traditional checkpointing method.
ref: https://pytorch.org/blog/activation-checkpointing-techniques/
I added an env var "SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET" to set the budget.
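For reference, a rough sketch of how such an env var could be wired up. If I'm reading the linked blog post right, the knob it describes is `torch._functorch.config.activation_memory_budget`; the exact placement in sd-scripts may differ.

```python
import os
import torch

# The memory budget is a float in [0, 1]: 1.0 keeps all activations
# (no recomputation), lower values trade recomputation for less memory.
# It only affects regions that go through torch.compile.
budget = os.environ.get("SDSCRIPTS_TORCH_COMPILE_ACTIVATION_MEMORY_BUDGET")
if budget is not None:
    torch._functorch.config.activation_memory_budget = float(budget)
```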
In my setup (training a rank 16 LoRA with batch size 2, resolution 1024, the budget set to "0.5", without "--gradient_checkpointing" but with "SDSCRIPTS_SELECTIVE_TORCH_COMPILE"), this reduces training time from 3.03s/it to 1.75s/it (1.5s/img to 0.88s/img), about 70% faster.
Note: Training Lumina 2 without gradient checkpointing will OOM (>24G VRAM) even with batch size 1.
Tested on the latest sd3 branch with Lumina 2 training: Torch 2.9 nightly (9/29/2025), CUDA 13.0, Python 3.12, NVIDIA 4090.
I'm not sure how to expose these options properly, and I have only tested Lumina 2, so I set them via env vars for now to keep the changes minimal. But "torch.compile + Memory Budget API" seems very powerful and could apply to any model by changing just a global setting.
Opening as a draft for suggestions.