diff --git a/internlm/core/parallel/comm/__init__.py b/internlm/core/parallel/comm/__init__.py
index e69de29bb..be170f286 100644
--- a/internlm/core/parallel/comm/__init__.py
+++ b/internlm/core/parallel/comm/__init__.py
@@ -0,0 +1,3 @@
+from .attn_offload import get_offload_manager, initialize_offload_manager
+
+__all__ = ["initialize_offload_manager", "get_offload_manager"]
diff --git a/internlm/core/parallel/comm/attn_offload.py b/internlm/core/parallel/comm/attn_offload.py
new file mode 100644
index 000000000..da23f3ae8
--- /dev/null
+++ b/internlm/core/parallel/comm/attn_offload.py
@@ -0,0 +1,127 @@
+import torch
+
+from internlm.utils.common import get_current_device
+
+global_attn_offload = None
+
+
+class AttnOffloadManager:
+    """
+    A manager for attention output CPU offloading and GPU prefetch loading.
+    """
+
+    def __init__(self, enable_cpu_offload: bool = False) -> None:
+        # whether to overlap attn output offloading to CPU with computation
+        self.cpu_offload = enable_cpu_offload
+        # mapping from layer id to flash attn output
+        self.fa_output_mapping = {}
+        self.fa_stream = torch.cuda.Stream()
+        self.d2h_final_event = torch.cuda.Event()
+        self.h2d_final_event = torch.cuda.Event()
+        # pre-allocated tensor buffers, double-buffered by layer id parity
+        self.tensor_id_to_tensor_bufs = {}
+
+    def get_tensor_buf_for_offloaded_tensor(self, tensor, layer_id, tensor_id):
+        """Get tensor buffer for offloaded tensor."""
+        layer_id = layer_id % 2
+        if layer_id not in self.tensor_id_to_tensor_bufs:
+            self.tensor_id_to_tensor_bufs[layer_id] = {}
+
+        if tensor_id not in self.tensor_id_to_tensor_bufs[layer_id]:
+            allocate_new_buf = True
+        else:
+            tensor_buf = self.tensor_id_to_tensor_bufs[layer_id][tensor_id]
+            allocate_new_buf = tensor_buf.size() != tensor.size() or tensor_buf.dtype != tensor.dtype
+
+        if allocate_new_buf:
+            # expected to execute only once per (layer slot, tensor id)
+            buffer = torch.empty(
+                tensor.size(),
+                dtype=tensor.dtype,
+                layout=tensor.layout,
+                device=tensor.device,
+            )
+
+            self.tensor_id_to_tensor_bufs[layer_id][tensor_id] = buffer
+
+        return self.tensor_id_to_tensor_bufs[layer_id][tensor_id]
+
+    def insert_fa_output_with_layer(self, layer_idx, output):
+        assert layer_idx not in self.fa_output_mapping
+        if self.cpu_offload is False:
+            self.fa_output_mapping[layer_idx] = output
+            return
+
+        tensors = []
+        for tensor_id, tensor in enumerate(output):
+            if tensor is None:
+                tensors.append(None)
+                continue
+            tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, layer_idx, tensor_id)
+            tensor_buf.copy_(tensor)
+            tensors.append(tensor_buf)
+        self.fa_output_mapping[layer_idx] = tensors
+
+    def get_fa_output_with_layer(self, layer_idx):
+        assert layer_idx in self.fa_output_mapping
+        return self.fa_output_mapping.pop(layer_idx)
+
+    def offload_fa_output_with_layer(self, layer_idx):
+        assert layer_idx in self.fa_output_mapping
+
+        self.fa_stream.wait_stream(torch.cuda.current_stream())
+        self.fa_stream.wait_event(self.d2h_final_event)
+
+        with torch.cuda.stream(self.fa_stream):
+            _gpu_tensors = self.fa_output_mapping.pop(layer_idx)
+            _cpu_tensors = []
+            for _tensor in _gpu_tensors:
+                if _tensor is None:
+                    _cpu_tensors.append(_tensor)
+                    continue
+
+                _cpu_backup = torch.empty(
+                    _tensor.size(),
+                    dtype=_tensor.dtype,
+                    layout=_tensor.layout,
+                    device="cpu",
+                    pin_memory=True,
+                )
+                _cpu_backup.copy_(_tensor, non_blocking=True)
+                _cpu_tensors.append(_cpu_backup)
+
+                # _cpu_tensors.append(_tensor.to("cpu", non_blocking=False))
+
+            self.fa_output_mapping[layer_idx] = _cpu_tensors
+
+        self.fa_stream.record_event(self.d2h_final_event)
+
+    def preload_fa_output_with_layer(self, layer_idx):
+        assert layer_idx in self.fa_output_mapping
+
+        self.fa_stream.wait_stream(torch.cuda.current_stream())
+        self.fa_stream.wait_event(self.h2d_final_event)
+
+        # Important: get the device before entering the stream context; calling get_current_device() inside it is incorrect
+        _device = get_current_device()
+        with torch.cuda.stream(self.fa_stream):
+            _cpu_tensors = self.fa_output_mapping.pop(layer_idx)
+            self.fa_output_mapping[layer_idx] = [
+                _tensor.to(device=_device, non_blocking=True) if _tensor is not None else _tensor
+                for _tensor in _cpu_tensors
+            ]
+
+        self.fa_stream.record_event(self.h2d_final_event)
+
+
+def initialize_offload_manager(enable_cpu_offload: bool = False):
+    global global_attn_offload
+    if global_attn_offload is None:
+        global_attn_offload = AttnOffloadManager(enable_cpu_offload)
+
+    return global_attn_offload
+
+
+def get_offload_manager():
+    assert global_attn_offload is not None
+    return global_attn_offload
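A minimal round trip through the manager, as a sketch only: it assumes a CUDA device, a flash-attn style five-element output tuple, and illustrative tensor shapes that are not taken from the patch.

    import torch

    from internlm.core.parallel.comm import initialize_offload_manager

    manager = initialize_offload_manager(enable_cpu_offload=True)

    out = torch.randn(4096, 32, 128, device="cuda", dtype=torch.bfloat16)
    softmax_lse = torch.randn(32, 4096, device="cuda", dtype=torch.float32)

    # forward of layer 0: stage the attention output in the double-buffered GPU buffers
    manager.insert_fa_output_with_layer(layer_idx=0, output=(out, None, softmax_lse, None, None))
    # copy it to pinned CPU memory on the side stream
    manager.offload_fa_output_with_layer(layer_idx=0)
    # backward/recompute: bring it back to the GPU ahead of the recomputation
    manager.preload_fa_output_with_layer(layer_idx=0)
    torch.cuda.synchronize()  # the sketch synchronizes; the real hooks rely on the recorded events
    cached_out, _, cached_lse, _, _ = manager.get_fa_output_with_layer(layer_idx=0)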
diff --git a/internlm/core/parallel/comm/isp.py b/internlm/core/parallel/comm/isp.py
index 7e722c2fe..24677c090 100644
--- a/internlm/core/parallel/comm/isp.py
+++ b/internlm/core/parallel/comm/isp.py
@@ -37,6 +37,8 @@
     params_dispatch_with_condition,
 )
 
+from .attn_offload import get_offload_manager
+
 
 # not really useful, only for code hint.
 class WPCommunicator(ABC):
@@ -306,6 +308,7 @@ def __init__(
         overlap: bool = False,
         process_group: dist.ProcessGroup = None,
         is_moe: bool = False,
+        selective_ckpt_offload: bool = False,
     ) -> None:
         self.process_group = process_group
         self.overlap = overlap
@@ -316,6 +319,14 @@
         self._forward_prefetch_prerequisites = []
         self._forward_overlap_per = self._get_forward_overlap_granularity()
         self._launch_before_module = self._get_launch_before_module()
+        # As an optimization, do not release the weights of the last transformer
+        # block after forward, since wp would prefetch them again immediately
+        self.layers_wp_not_release = []  # [gpc.config.isp_num_layers - 1]
+        self.layers_fa_not_release = [
+            gpc.config.isp_num_layers - 1,
+            int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) - 1,
+        ]
+        self.sc_offload = selective_ckpt_offload
         # real overlap state for each chunk.
         self._overlap_states: Dict[int, ISPOverlapState] = {}
 
@@ -411,6 +422,7 @@ def is_allgather_launch_module(name, module):
                 self._overlap_states[cid].index_to_isp_modules[idx].append(child)
 
                 setattr(child, "isp_name", name)
+                setattr(child, "isp_layer_idx", idx)
 
                 full_name = f"{cid}.{idx}.{name}"
                 setattr(
@@ -506,6 +518,25 @@ def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args)
         if block_index + 1 < self._num_blocks:
             self._all_gather_block_weight(block_index + 1)
 
+        # launch offload and prefetch for selective ckpt together with the wo linear's weight prefetch
+        if self.sc_offload is True:
+            # move the current layer's attn output from GPU to CPU asynchronously
+            if (
+                self.is_forward is True
+                and gpc.config.selective_checkpoint
+                and block_index not in self.layers_fa_not_release
+                and block_index < self._ckpt_block_num
+            ):
+                get_offload_manager().offload_fa_output_with_layer(layer_idx=block_index)
+
+            # load the previous layer's attn output from CPU to GPU asynchronously
+            if (
+                self.is_forward is False
+                and gpc.config.selective_checkpoint
+                and (0 <= (block_index - 1) < self._ckpt_block_num)
+            ):
+                get_offload_manager().preload_fa_output_with_layer(layer_idx=block_index - 1)
+
     def _pre_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
         if module not in self._weight_global_handle:
             self._all_gather_module_weight(module)
@@ -539,6 +570,9 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: dis
             self._all_gather_module_weight(next_module)
 
     def _post_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
+        if int(module.isp_layer_idx) in self.layers_wp_not_release:
+            # print(f"the layer {module.isp_layer_idx} after forward not clear weight")
+            return
         if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False):
             self._clear_handle(module)
             self._clear_weight(module)
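To make the hook conditions above concrete, a worked example with hypothetical values (isp_num_layers = 32, model.checkpoint = 0.5; neither value is prescribed by the patch):

    # hypothetical sizes: 32 transformer blocks, half of them recomputed in backward
    isp_num_layers = 32
    checkpoint_ratio = 0.5

    ckpt_block_num = int(checkpoint_ratio * isp_num_layers)            # 16 -> blocks 0..15 are checkpointed
    layers_fa_not_release = [isp_num_layers - 1, ckpt_block_num - 1]   # [31, 15]

    # forward: blocks 0..14 offload their attn output; block 15 (the last checkpointed
    # block) keeps its output on the GPU because it is excluded above
    offloaded = [i for i in range(ckpt_block_num) if i not in layers_fa_not_release]

    # backward: the recompute of block i prefetches block i - 1's output, for i in 1..15
    prefetch_pairs = [(i, i - 1) for i in range(1, ckpt_block_num)]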
diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py
index d0ef284d4..cb45be728 100644
--- a/internlm/core/trainer_builder.py
+++ b/internlm/core/trainer_builder.py
@@ -11,6 +11,7 @@
 from internlm.checkpoint.checkpoint_manager import CheckpointManager
 from internlm.core.context import global_context as gpc
 from internlm.core.context.process_group_initializer import ParallelMode
+from internlm.core.parallel.comm import initialize_offload_manager
 from internlm.core.trainer import Trainer
 from internlm.data.streaming.utils import streaming_simple_resume
 from internlm.data.train_state import get_train_state
@@ -118,6 +119,9 @@ def __init__(
         # initialize isp communicator
         isp_communicator = initialize_parallel_communicator(model)
 
+        # initialize cpu offload manager for selective checkpoint
+        initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False))
+
         # initialize train state
         train_state = get_train_state(train_dl)
 
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 35b3d646c..580ba0069 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -66,6 +66,8 @@ def get_default_parser():
 
 def args_sanity_check():
     assert gpc.config is not None, "config is not load!"
 
+    gpc.is_forward = True
+
     if "JOB_NAME" not in gpc.config:
         gpc.config._add_item("JOB_NAME", "AnonymousJob")
@@ -73,6 +75,13 @@
     if "model_type" not in gpc.config:
         gpc.config._add_item("model_type", ModelType.INTERNLM.name)
 
+    if gpc.config.model_type == "InternLM3_M":
+        # TODO: need check for isp overlap
+        num_layers = gpc.config.model.num_self_decoder_layers + gpc.config.model.num_cross_decoder_layers
+    else:
+        num_layers = gpc.config.model.num_layers
+    gpc.config.isp_num_layers = num_layers
+
     if "use_apex_adam" not in gpc.config:
         gpc.config._add_item("use_apex_adam", False)
 
@@ -399,17 +408,18 @@
         gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name)
     if gpc.config.parallel["tensor"].get("mode", None) is None:
         gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name
-    assert (
-        gpc.config.VOCAB_SIZE % gpc.config.parallel.tensor.size == 0
-    ), "VOCAB_SIZE must be integer multiple of tensor parallel size"
     if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name:
         assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp"
         assert (
             torch.__version__ >= "2.1.0"
         ), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}"
-        assert (
-            gpc.config.VOCAB_SIZE % gpc.config.parallel.weight.size == 0
-        ), "VOCAB_SIZE must be integer multiple of wp size"
+
+    assert (
+        gpc.config.model.vocab_size % gpc.config.parallel.weight.size == 0
+    ), "model.vocab_size must be an integer multiple of the weight parallel size"
+    assert (
+        gpc.config.model.vocab_size % gpc.config.parallel.tensor.size == 0
+    ), "model.vocab_size must be an integer multiple of the tensor parallel size"
 
     assert gpc.config.parallel["tensor"].get("mode", None) in [
         TensorParallelMode.mtp.name,
@@ -532,7 +542,20 @@
         gpc.config.loss._add_item("moe_loss_coeff", 1.0)
 
     if "selective_checkpoint" not in gpc.config:
-        gpc.config._add_item("selective_checkpoint", False)
+        gpc.config.selective_checkpoint = False
+    if "selective_checkpoint_offload" not in gpc.config:
+        gpc.config.selective_checkpoint_offload = False
+    if gpc.config.selective_checkpoint is True:
+        assert (
+            gpc.config.parallel["tensor"]["mode"] == "isp"
+        ), "When using selective_checkpoint, tensor parallel mode must be isp"
+    if gpc.config.selective_checkpoint_offload is True:
+        assert (
+            gpc.config.selective_checkpoint is True
+        ), "When using selective_checkpoint_offload, selective_checkpoint must be True"
+        assert (
+            gpc.config.parallel.weight.launch_allgather_before == "wo"
+        ), "When using selective_checkpoint_offload, parallel.weight.launch_allgather_before must be set to 'wo'"
 
     # moe not support overlap and zero1.5 for now
     if gpc.config.model.get("num_experts", 1) > 1:
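A config fragment that satisfies the new sanity checks might look as follows; this is a sketch, and the parallel sizes, layer count, and vocabulary size are illustrative rather than values required by the patch:

    selective_checkpoint = True
    selective_checkpoint_offload = True

    model = dict(
        checkpoint=0.5,  # fraction of blocks to recompute; also bounds the offload window
        num_layers=32,
        vocab_size=103168,
    )

    parallel = dict(
        zero1=dict(size=-1),
        tensor=dict(size=2, mode="isp"),
        pipeline=dict(size=1),
        weight=dict(size=4, overlap=True, launch_allgather_before="wo"),
    )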
diff --git a/internlm/model/ops/_flash_attn.py b/internlm/model/ops/_flash_attn.py
new file mode 100644
index 000000000..87aac2eb8
--- /dev/null
+++ b/internlm/model/ops/_flash_attn.py
@@ -0,0 +1,331 @@
+# Copyright (c) InternLM. All rights reserved.
+import torch
+
+from internlm.accelerator import get_accelerator
+from internlm.core.context import global_context as gpc
+from internlm.core.parallel.comm import get_offload_manager
+
+try:
+    import flash_attn
+    from flash_attn.flash_attn_interface import (
+        _flash_attn_varlen_backward,
+        _flash_attn_varlen_forward,
+    )
+
+    gpu_flash_attn_impl = True
+except (ModuleNotFoundError, ImportError):
+    gpu_flash_attn_impl = False
+
+internlm_accelerator = get_accelerator()
+device_backend = internlm_accelerator.get_accelerator_backend()
+
+
+class FlashAttnVarlenKVPackedFunc_V263(torch.autograd.Function):
+    """
+    Varlen KVPacked Func from Flash Attn v2.6.3.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        q,
+        kv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        softcap,
+        alibi_slopes,
+        deterministic,
+        return_softmax,
+        layer_idx,
+    ):
+        if softmax_scale is None:
+            softmax_scale = q.shape[-1] ** (-0.5)
+
+        k, v = kv[:, 0], kv[:, 1]
+
+        _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers)
+
+        if gpc.is_forward is False and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num:
+            out, out_padded, softmax_lse, S_dmask, rng_state = get_offload_manager().get_fa_output_with_layer(layer_idx)
+        else:
+            (
+                out,
+                q,
+                k,
+                v,
+                out_padded,
+                softmax_lse,
+                S_dmask,
+                rng_state,
+            ) = _flash_attn_varlen_forward(  # pylint: disable=E1123
+                q,
+                k,
+                v,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                dropout_p,
+                softmax_scale,
+                causal=causal,
+                window_size=window_size,
+                softcap=softcap,
+                alibi_slopes=alibi_slopes,
+                return_softmax=return_softmax and dropout_p > 0,
+                block_table=None,
+            )
+
+        # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled
+        if gpc.is_forward and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num:
+            get_offload_manager().insert_fa_output_with_layer(
+                layer_idx=layer_idx, output=(out, out_padded, softmax_lse, S_dmask, rng_state)
+            )
+
+        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
+        ctx.dropout_p = dropout_p
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.alibi_slopes = alibi_slopes
+        ctx.deterministic = deterministic
+        return out if not return_softmax else (out, softmax_lse, S_dmask)
+
+    @staticmethod
+    def backward(ctx, dout, *args):  # pylint: disable=W0613
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
+        dq = torch.empty_like(q)
+        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
+        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
+        _flash_attn_varlen_backward(  # pylint: disable=E1121,E1124
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            dq,
+            dkv[:, 0],
+            dkv[:, 1],
+            cu_seqlens_q,
+            cu_seqlens_k,
+            ctx.max_seqlen_q,
+            ctx.max_seqlen_k,
+            ctx.dropout_p,
+            ctx.softmax_scale,
+            ctx.causal,
+            ctx.window_size,
+            ctx.softcap,
+            ctx.alibi_slopes,
+            ctx.deterministic,
+            rng_state=rng_state,
+        )
+        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
+        dkv = dkv[..., : dout.shape[-1]]
+        return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None, None
+
+
+class FlashAttnVarlenKVPackedFunc_V221(torch.autograd.Function):
+    """
+    Varlen KVPacked Func from Flash Attn v2.2.1.
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        q,
+        kv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+        return_softmax,
+        layer_idx,
+    ):
+        if softmax_scale is None:
+            softmax_scale = q.shape[-1] ** (-0.5)
+
+        k, v = kv[:, 0], kv[:, 1]
+
+        _ckpt_block_num = int(gpc.config.model.checkpoint * gpc.config.isp_num_layers)
+
+        if gpc.is_forward is False and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num:
+            out, out_padded, softmax_lse, S_dmask, rng_state = get_offload_manager().get_fa_output_with_layer(layer_idx)
+        else:
+            out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
+                q,
+                k,
+                v,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                dropout_p,
+                softmax_scale,
+                causal=causal,
+                return_softmax=return_softmax and dropout_p > 0,
+            )
+
+        # store attn forward output to avoid re-computation of attn when activation checkpoint is enabled
+        if gpc.is_forward and gpc.config.selective_checkpoint and layer_idx < _ckpt_block_num:
+            get_offload_manager().insert_fa_output_with_layer(
+                layer_idx=layer_idx, output=(out, out_padded, softmax_lse, S_dmask, rng_state)
+            )
+        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
+        ctx.dropout_p = dropout_p
+        ctx.max_seqlen_q = max_seqlen_q
+        ctx.max_seqlen_k = max_seqlen_k
+        ctx.softmax_scale = softmax_scale
+        ctx.causal = causal
+        return out if not return_softmax else (out, softmax_lse, S_dmask)
+
+    @staticmethod
+    def backward(ctx, dout, *args):  # pylint: disable=W0613
+        q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
+        dq = torch.empty_like(q)
+        kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
+        dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
+        _flash_attn_varlen_backward(
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            dq,
+            dkv[:, 0],
+            dkv[:, 1],
+            cu_seqlens_q,
+            cu_seqlens_k,
+            ctx.max_seqlen_q,
+            ctx.max_seqlen_k,
+            ctx.dropout_p,
+            ctx.softmax_scale,
+            ctx.causal,
+            rng_state=rng_state,
+        )
+        dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
+        dkv = dkv[..., : dout.shape[-1]]
+        return dq, dkv, None, None, None, None, None, None, None, None, None
+
+
+def flash_attn_varlen_kvpacked_func(
+    q,
+    kv,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    window_size=(-1, -1),  # -1 means infinite context window
+    softcap=0.0,  # 0.0 means deactivated
+    alibi_slopes=None,
+    deterministic=False,
+    return_attn_probs=False,
+    layer_idx=0,
+):
+    """dropout_p should be set to 0.0 during evaluation
+    If K, V are already stacked into 1 tensor, this function will be faster than
+    calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
+    of the gradients of K, V.
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, heads 0, 1, 2 of Q will attend to head
+    0 of K, V, and heads 3, 4, 5 of Q will attend to head 1 of K, V.
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+    Arguments:
+        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
+        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into q.
+        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into kv.
+        max_seqlen_q: int. Maximum query sequence length in the batch.
+        max_seqlen_k: int. Maximum key sequence length in the batch.
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
+            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
+            is added to the attention score of query i and key j.
+        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
+            which is slightly slower and uses more memory. The forward pass is always deterministic.
+        return_attn_probs: bool. Whether to return the attention probabilities. This option is for
+            testing only. The returned probabilities are not guaranteed to be correct
+            (they might not have the right scaling).
+    Return:
+        out: (total, nheads, headdim).
+        softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The
+            logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
+            normalization factor).
+        S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
+            The output of softmax (possibly with different scaling). It also encodes the dropout
+            pattern (negative means that location was dropped, nonnegative means it was kept).
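A usage sketch of the wrapper, assuming flash-attn 2.6.3 is installed and an initialized InternEvo context (gpc with isp_num_layers and selective_checkpoint already configured); the shapes and sequence lengths are illustrative:

    import torch

    from internlm.model.ops._flash_attn import flash_attn_varlen_kvpacked_func

    # two packed sequences of lengths 3 and 5, 8 query heads, 2 KV heads (GQA), head dim 64
    total_tokens, nheads, nheads_k, headdim = 8, 8, 2, 64
    q = torch.randn(total_tokens, nheads, headdim, device="cuda", dtype=torch.bfloat16)
    kv = torch.randn(total_tokens, 2, nheads_k, headdim, device="cuda", dtype=torch.bfloat16)
    cu_seqlens = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)

    out = flash_attn_varlen_kvpacked_func(
        q,
        kv,
        cu_seqlens,
        cu_seqlens,
        max_seqlen_q=5,
        max_seqlen_k=5,
        causal=True,
        layer_idx=0,
    )  # out: (total_tokens, nheads, headdim)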
+    """
+
+    assert gpu_flash_attn_impl is True and flash_attn.__version__ in [
+        "2.2.1",
+        "2.6.3",
+    ], "flash-attn should be installed and version must be v2.2.1 or v2.6.3"
+
+    if flash_attn.__version__ == "2.2.1":
+        return FlashAttnVarlenKVPackedFunc_V221.apply(
+            q,
+            kv,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            softmax_scale,
+            causal,
+            return_attn_probs,
+            layer_idx,
+        )
+
+    return FlashAttnVarlenKVPackedFunc_V263.apply(
+        q,
+        kv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+        window_size,
+        softcap,
+        alibi_slopes,
+        deterministic,
+        return_attn_probs,
+        layer_idx,
+    )
diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py
index 604ea77a7..3aec51f55 100644
--- a/internlm/model/ops/attention.py
+++ b/internlm/model/ops/attention.py
@@ -93,13 +93,14 @@
     from flash_attn.flash_attn_interface import (
         flash_attn_varlen_func as _flash_varlen_qkvsplited_func,
     )
-    from flash_attn.flash_attn_interface import (
-        flash_attn_varlen_kvpacked_func as _flash_varlen_kvpacked_func,
-    )
     from flash_attn.flash_attn_interface import (
         flash_attn_varlen_qkvpacked_func as _flash_varlen_qkvpacked_func,
     )
 
+    from ._flash_attn import (
+        flash_attn_varlen_kvpacked_func as _flash_varlen_kvpacked_func,
+    )
+
     gpu_flash_attn_impl = True
 except (ModuleNotFoundError, ImportError):
     gpu_flash_attn_impl = False
@@ -187,6 +188,7 @@ def _flash_varlen_kvpacked_attn(
     dropout_p=0.0,
     softmax_scale=None,
     causal=False,
+    layer_idx=0,
 ):
     # compatible data format: [1, packelen, 3, n_head, headim]
     q, kv = q.squeeze(dim=0), kv.squeeze(dim=0)
@@ -204,6 +206,7 @@ def _flash_varlen_kvpacked_attn(
         dropout_p,
         softmax_scale,
         causal,
+        layer_idx=layer_idx,
     )
 
     return output.unsqueeze(dim=0)
@@ -521,6 +524,7 @@ def _npu_varlen_kvpacked_attn(
     dropout_p=0.0,
     softmax_scale=None,
     causal=False,
+    layer_idx=0,  # pylint: disable=W0613
 ):
     # TODO: support npu native varlen flash attention
     k, v = kv.unbind(dim=2)
@@ -579,6 +583,7 @@ def _deeplink_varlen_kvpacked_attn(
     dropout_p=0.0,
     softmax_scale=None,
     causal=False,
+    layer_idx=0,  # pylint: disable=W0613
 ):
     # compatible data format: [1, packelen, 3, n_head, headim]
     q, kv = q.squeeze(dim=0), kv.squeeze(dim=0)
@@ -1012,7 +1017,17 @@ def _q_kv_with_cu_seqlens(
         extra_args = (key_padding_mask,) if attn_type is AttnType.Torch else ()
 
         return op(
-            q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout, softmax_scale, causal, *extra_args
+            q,
+            kv,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout,
+            softmax_scale,
+            causal,
+            *extra_args,
+            layer_idx=self.layer_idx,
         )
 
     @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.With)))
diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py
index 79e9caf43..ca11e6898 100644
--- a/internlm/train/pipeline.py
+++ b/internlm/train/pipeline.py
@@ -363,6 +363,7 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]):
             gpc.config.parallel.weight.overlap,
             gpc.get_group(ParallelMode.WEIGHT),
             is_moe=False,
+            selective_ckpt_offload=gpc.config.get("selective_checkpoint_offload", False),
         )
         # register communicator for isp column parallel linear.
         ColumnParallelLinear.register_cls_communicator(isp_communicator)