feat(comm/attn_offload.py): support selective ckpt and cpu offload #383
Motivation
Support selective checkpointing and asynchronous CPU offload to improve performance.
When selective_checkpoint is set to True, the intermediate activations of the attention part are stored for the layers that are recomputed, so the attention part can be skipped during recomputation, which improves training performance.
When selective_checkpoint_offload is further set to True, the intermediate activations of the attention part for the recomputed layers are asynchronously offloaded to the CPU to save GPU memory.
However, current testing shows that when selective_checkpoint_offload is set to True, the DtoH and HtoD operations compete with allgather and other communications for bandwidth, which increases allgather communication time and lowers overall performance. It is therefore advisable not to enable selective_checkpoint_offload unless it is necessary.
Modification
internlm/core/parallel/comm/attn_offload.py: AttnOffloadManager, a manager for attention output CPU offloading and GPU prefetch loading.
Use cases (Optional)
example config:
Note: this should be used with isp, and only GQA is currently supported.
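To make the offload and prefetch flow from the motivation concrete, here is a toy sketch in plain Python. It is not the AttnOffloadManager from this PR: the class ToyAttnOffloadManager and its methods (save, prefetch, get, close) are hypothetical names, the "GPU" and "CPU" are ordinary dicts, and a single worker thread stands in for an asynchronous copy stream. It only illustrates the ordering: the attention output is saved in the forward pass, optionally parked on the CPU in the background, and fetched back before recomputation needs it.

```python
from concurrent.futures import ThreadPoolExecutor


class ToyAttnOffloadManager:
    """Hypothetical stand-in for illustration; not the PR's implementation."""

    def __init__(self, offload: bool) -> None:
        self.offload = offload  # plays the role of selective_checkpoint_offload
        self.gpu = {}           # layer index -> attention output on the "GPU"
        self.cpu = {}           # layer index -> attention output on the "CPU"
        self._h2d = {}          # layer index -> in-flight "HtoD" copy
        # One worker runs copies in submission order, like a single copy stream.
        self._copier = ThreadPoolExecutor(max_workers=1)

    def save(self, layer: int, attn_out) -> None:
        """Forward pass: keep the attention output of a recomputed layer."""
        if self.offload:
            # Asynchronous "DtoH": the caller does not wait for the copy.
            self._copier.submit(self.cpu.__setitem__, layer, attn_out)
        else:
            self.gpu[layer] = attn_out

    def _load(self, layer: int) -> None:
        self.gpu[layer] = self.cpu.pop(layer)

    def prefetch(self, layer: int) -> None:
        """Backward pass: start the "HtoD" copy ahead of recomputation."""
        if self.offload and layer not in self._h2d:
            self._h2d[layer] = self._copier.submit(self._load, layer)

    def get(self, layer: int):
        """Recomputation: return the saved output so attention can be skipped."""
        if self.offload:
            self.prefetch(layer)             # no-op if already prefetched
            self._h2d.pop(layer).result()    # wait until the copy has landed
        return self.gpu.pop(layer)

    def close(self) -> None:
        self._copier.shutdown(wait=True)


if __name__ == "__main__":
    mgr = ToyAttnOffloadManager(offload=True)
    for i in range(3):                       # forward over three layers
        mgr.save(i, [float(i)] * 4)
    for i in reversed(range(3)):             # backward, prefetching one layer ahead
        if i > 0:
            mgr.prefetch(i - 1)
        print(i, mgr.get(i))
    mgr.close()
```

In a real implementation the background copies would share bandwidth with collective communication, which is the contention described in the motivation; the toy version has no such cost and says nothing about actual performance.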
loss accuracy checking
Checklist
Before PR:
After PR: