feat(comm/attn_offload.py): support selective ckpt and cpu offload #383

Open · wants to merge 5 commits into base: develop
Conversation

@huangting4201 huangting4201 commented Dec 3, 2024

Motivation

Support selective checkpoint and asynchronous CPU offload to improve performance.

  1. When enabling checkpointing, if selective_checkpoint is set to True, for layers that are recomputed, storing the intermediate activations of the attention part allows the attention part to be skipped during recomputation, thereby enhancing training performance.
  2. If selective_checkpoint_offload is further set to True, the intermediate activations of the attention part for the recomputed layers will be asynchronously offloaded to the CPU to save GPU memory.
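The selective-checkpoint idea in step 1 can be sketched as follows. This is a minimal illustration of the mechanism, not the actual internlm implementation; all names here are hypothetical, and the real code operates on GPU tensors inside the activation-checkpoint machinery.

```python
# Hypothetical sketch: cache the attention output during the checkpointed
# forward pass so the recompute pass can skip the attention part.

attn_cache = {}       # layer index -> cached attention output
attn_call_count = 0   # counts how many times "attention" actually runs

def attention(x):
    """Stand-in for the expensive attention computation."""
    global attn_call_count
    attn_call_count += 1
    return [v * 2 for v in x]

def mlp(x):
    """Stand-in for the rest of the transformer layer."""
    return [v + 1 for v in x]

def layer_forward(layer_idx, x, recompute=False, selective_checkpoint=True):
    if recompute and selective_checkpoint and layer_idx in attn_cache:
        # Recompute pass: reuse the stored activation, skip attention entirely.
        attn_out = attn_cache.pop(layer_idx)
    else:
        attn_out = attention(x)
        if selective_checkpoint:
            attn_cache[layer_idx] = attn_out  # keep for the backward recompute
    return mlp(attn_out)

x = [1.0, 2.0]
y1 = layer_forward(0, x)                  # forward: attention runs, output cached
y2 = layer_forward(0, x, recompute=True)  # recompute: attention is skipped
assert y1 == y2
assert attn_call_count == 1               # attention ran only once
```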

Note, however, that testing has shown that when selective_checkpoint_offload is set to True, the DtoH and HtoD copies compete with allgather and other communication for bandwidth, increasing allgather communication time and degrading overall performance. It is therefore advisable to leave selective_checkpoint_offload disabled unless the GPU memory savings are needed.

Modification

internlm/core/parallel/comm/attn_offload.py: AttnOffloadManager, a manager for attention output CPU offloading and GPU prefetch loading.
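A rough sketch of the bookkeeping such a manager performs is shown below. The class and method names are assumptions for illustration only; the actual AttnOffloadManager works with GPU tensors and performs the DtoH/HtoD copies asynchronously (e.g. with pinned host memory and a dedicated CUDA stream), which this plain-Python sketch only mimics.

```python
# Hypothetical sketch of an attention-output offload manager: save activations
# during forward, optionally "offload" them to host memory, and prefetch them
# back before the recompute pass.

class AttnOffloadManagerSketch:
    def __init__(self, offload=True):
        self.offload = offload
        self.gpu_store = {}  # layer idx -> activation resident "on GPU"
        self.cpu_store = {}  # layer idx -> activation offloaded "to CPU"

    def save(self, layer_idx, activation):
        """Save an attention output; optionally offload it to host memory."""
        if self.offload:
            # Real impl would do an async DtoH copy on a side stream.
            self.cpu_store[layer_idx] = list(activation)
        else:
            self.gpu_store[layer_idx] = activation

    def prefetch(self, layer_idx):
        """Bring an offloaded activation back (HtoD) ahead of recomputation."""
        if layer_idx in self.cpu_store:
            self.gpu_store[layer_idx] = self.cpu_store.pop(layer_idx)

    def load(self, layer_idx):
        """Retrieve the activation for the recompute pass."""
        self.prefetch(layer_idx)  # no-op if already resident
        return self.gpu_store.pop(layer_idx)

mgr = AttnOffloadManagerSketch(offload=True)
mgr.save(3, [0.5, 1.5])
mgr.prefetch(3)  # in the real code this copy overlaps with compute
assert mgr.load(3) == [0.5, 1.5]
```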

Use cases (Optional)

example config:

selective_checkpoint = True
selective_checkpoint_offload = False
model = dict(
    num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
    checkpoint=1,  # The proportion of layers for activation checkpointing; optional values are True/False/[0-1]
    ......
)

Note: this feature must be used together with isp, and only GQA is supported for now.

loss accuracy checking
[two loss-curve screenshots attached]

Checklist

Before PR:

  • Pre-commit or other linting tools are used to fix the potential lint issues.
  • Bug fixes are fully covered by unit tests; the case that caused the bug should be added to the unit tests.
  • The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  • The documentation has been modified accordingly, like docstring or example tutorials.

After PR:

  • If the modification has potential influence on downstream or other related projects, this PR should be tested with those projects.
  • CLA has been signed and all committers have signed the CLA in this PR.

@huangting4201
Collaborator Author

Depends on PR #381
