3 changes: 3 additions & 0 deletions internlm/core/parallel/comm/__init__.py
@@ -0,0 +1,3 @@
from .attn_offload import get_offload_manager, initialize_offload_manager

__all__ = ["initialize_offload_manager", "get_offload_manager"]
127 changes: 127 additions & 0 deletions internlm/core/parallel/comm/attn_offload.py
@@ -0,0 +1,127 @@
import torch

from internlm.utils.common import get_current_device

global_attn_offload = None


class AttnOffloadManager:
"""
    A manager that offloads attention outputs to CPU and prefetches them back to the GPU.
"""

def __init__(self, enable_cpu_offload: bool = False) -> None:
        # whether CPU offloading (overlapped with computation) is enabled
self.cpu_offload = enable_cpu_offload
        # mapping from layer id to flash attention output
self.fa_output_mapping = {}
self.fa_stream = torch.cuda.Stream()
self.d2h_final_event = torch.cuda.Event()
self.h2d_final_event = torch.cuda.Event()
# prepare for tensor buffer
self.tensor_id_to_tensor_bufs = {}

def get_tensor_buf_for_offloaded_tensor(self, tensor, layer_id, tensor_id):
"""Get tensor buffer for offloaded tensor."""
layer_id = layer_id % 2
if layer_id not in self.tensor_id_to_tensor_bufs:
self.tensor_id_to_tensor_bufs[layer_id] = {}

if tensor_id not in self.tensor_id_to_tensor_bufs[layer_id]:
allocate_new_buf = True
else:
tensor_buf = self.tensor_id_to_tensor_bufs[layer_id][tensor_id]
            allocate_new_buf = tensor_buf.size() != tensor.size() or tensor_buf.dtype != tensor.dtype

if allocate_new_buf:
            # allocate (or re-allocate on shape/dtype mismatch) the reusable buffer; normally happens once per slot
buffer = torch.empty(
tensor.size(),
dtype=tensor.dtype,
layout=tensor.layout,
device=tensor.device,
)

self.tensor_id_to_tensor_bufs[layer_id][tensor_id] = buffer

return self.tensor_id_to_tensor_bufs[layer_id][tensor_id]

def insert_fa_output_with_layer(self, layer_idx, output):
assert layer_idx not in self.fa_output_mapping
if self.cpu_offload is False:
self.fa_output_mapping[layer_idx] = output
return

tensors = []
for tensor_id, tensor in enumerate(output):
if tensor is None:
tensors.append(None)
continue
tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, layer_idx, tensor_id)
tensor_buf.copy_(tensor)
tensors.append(tensor_buf)
self.fa_output_mapping[layer_idx] = tensors

def get_fa_output_with_layer(self, layer_idx):
assert layer_idx in self.fa_output_mapping
return self.fa_output_mapping.pop(layer_idx)

def offload_fa_output_with_layer(self, layer_idx):
assert layer_idx in self.fa_output_mapping

self.fa_stream.wait_stream(torch.cuda.current_stream())
self.fa_stream.wait_event(self.d2h_final_event)

with torch.cuda.stream(self.fa_stream):
_gpu_tensors = self.fa_output_mapping.pop(layer_idx)
_cpu_tensors = []
for _tensor in _gpu_tensors:
if _tensor is None:
_cpu_tensors.append(_tensor)
continue

_cpu_backup = torch.empty(
_tensor.size(),
dtype=_tensor.dtype,
layout=_tensor.layout,
device="cpu",
pin_memory=True,
)
_cpu_backup.copy_(_tensor, non_blocking=True)
_cpu_tensors.append(_cpu_backup)

# _cpu_tensors.append(_tensor.to("cpu", non_blocking=False))

self.fa_output_mapping[layer_idx] = _cpu_tensors

self.fa_stream.record_event(self.d2h_final_event)

def preload_fa_output_with_layer(self, layer_idx):
assert layer_idx in self.fa_output_mapping

self.fa_stream.wait_stream(torch.cuda.current_stream())
self.fa_stream.wait_event(self.h2d_final_event)

        # Important: query the device before entering the stream context; calling get_current_device() inside it fails
_device = get_current_device()
with torch.cuda.stream(self.fa_stream):
_cpu_tensors = self.fa_output_mapping.pop(layer_idx)
self.fa_output_mapping[layer_idx] = [
_tensor.to(device=_device, non_blocking=True) if _tensor is not None else _tensor
for _tensor in _cpu_tensors
]

self.fa_stream.record_event(self.h2d_final_event)


def initialize_offload_manager(enable_cpu_offload: bool = False):
global global_attn_offload
if global_attn_offload is None:
global_attn_offload = AttnOffloadManager(enable_cpu_offload)

return global_attn_offload


def get_offload_manager():
assert global_attn_offload is not None
return global_attn_offload
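
For orientation, a minimal usage sketch of the offload/prefetch cycle implemented above (the layer index, tensor shapes, and the explicit event wait are illustrative assumptions, not taken from this PR):

import torch

from internlm.core.parallel.comm import get_offload_manager, initialize_offload_manager

manager = initialize_offload_manager(enable_cpu_offload=True)

# forward of layer 0: stash the flash-attention outputs (hypothetical shapes)
fa_output = (torch.randn(2, 1024, 32, 128, device="cuda"), None)
manager.insert_fa_output_with_layer(layer_idx=0, output=fa_output)

# a later forward hook triggers the async D2H copy on the side stream
manager.offload_fa_output_with_layer(layer_idx=0)

# during backward recomputation, prefetch the tensors back to the GPU ...
manager.preload_fa_output_with_layer(layer_idx=0)

# ... wait for the H2D copy and consume them (the entry is popped on retrieval)
torch.cuda.current_stream().wait_event(manager.h2d_final_event)
recovered = get_offload_manager().get_fa_output_with_layer(layer_idx=0)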
34 changes: 34 additions & 0 deletions internlm/core/parallel/comm/isp.py
@@ -37,6 +37,8 @@
params_dispatch_with_condition,
)

from .attn_offload import get_offload_manager


# not really useful, only for code hint.
class WPCommunicator(ABC):
@@ -306,6 +308,7 @@ def __init__(
overlap: bool = False,
process_group: dist.ProcessGroup = None,
is_moe: bool = False,
selective_ckpt_offload: bool = False,
) -> None:
self.process_group = process_group
self.overlap = overlap
@@ -316,6 +319,14 @@ def __init__(
self._forward_prefetch_prerequisites = []
self._forward_overlap_per = self._get_forward_overlap_granularity()
self._launch_before_module = self._get_launch_before_module()
        # As an optimization, do not release the weight after forward for the last
        # transformer block, since weight parallel would prefetch it again immediately
self.layers_wp_not_release = [] # [gpc.config.isp_num_layers - 1]
self.layers_fa_not_release = [
gpc.config.isp_num_layers - 1,
int(gpc.config.model.checkpoint * gpc.config.isp_num_layers) - 1,
]
self.sc_offload = selective_ckpt_offload

# real overlap state for each chunk.
self._overlap_states: Dict[int, ISPOverlapState] = {}
@@ -411,6 +422,7 @@ def is_allgather_launch_module(name, module):
self._overlap_states[cid].index_to_isp_modules[idx].append(child)

setattr(child, "isp_name", name)
setattr(child, "isp_layer_idx", idx)

full_name = f"{cid}.{idx}.{name}"
setattr(
@@ -506,6 +518,25 @@ def _pre_forward_hook_for_prefetch_launch_module(self, module: nn.Module, *args)
if block_index + 1 < self._num_blocks:
self._all_gather_block_weight(block_index + 1)

        # register offload and prefetch hooks for selective checkpoint with the wo linear
if self.sc_offload is True:
            # asynchronously move the current layer's attn output from GPU to CPU
if (
self.is_forward is True
and gpc.config.selective_checkpoint
and block_index not in self.layers_fa_not_release
and block_index < self._ckpt_block_num
):
get_offload_manager().offload_fa_output_with_layer(layer_idx=block_index)

            # asynchronously load the previous layer's attn output from CPU back to GPU
if (
self.is_forward is False
and gpc.config.selective_checkpoint
and (0 <= (block_index - 1) < self._ckpt_block_num)
):
get_offload_manager().preload_fa_output_with_layer(layer_idx=block_index - 1)

def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613
if module not in self._weight_global_handle:
self._all_gather_module_weight(module)
@@ -539,6 +570,9 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: dis
self._all_gather_module_weight(next_module)

def _post_forward_hook_for_module(self, module: nn.Module, *args): # pylint: disable=W0613
if int(module.isp_layer_idx) in self.layers_wp_not_release:
# print(f"the layer {module.isp_layer_idx} after forward not clear weight")
return
if not ((self._module_to_index[module] < self._ckpt_block_num) and self.is_forward is False):
self._clear_handle(module)
self._clear_weight(module)
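To show how a model would plug into these hooks, a rough sketch of the attention path under selective checkpoint (the helper name and call sites are hypothetical; the actual wiring lives in the model code, not in this diff):

from internlm.core.context import global_context as gpc
from internlm.core.parallel.comm import get_offload_manager


def attn_with_selective_ckpt(layer_idx, q, k, v, flash_attn_fn):
    """Hypothetical helper: save the flash-attention output during the checkpointed
    forward, and reuse the prefetched copy during backward recomputation instead of
    recomputing attention."""
    manager = get_offload_manager()
    if gpc.config.selective_checkpoint and not gpc.is_forward:
        # recomputation pass: the preload hook has already queued the H2D copy
        (out,) = manager.get_fa_output_with_layer(layer_idx)
        return out
    out = flash_attn_fn(q, k, v)
    if gpc.config.selective_checkpoint and gpc.is_forward:
        manager.insert_fa_output_with_layer(layer_idx, (out,))
    return out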
4 changes: 4 additions & 0 deletions internlm/core/trainer_builder.py
@@ -11,6 +11,7 @@
from internlm.checkpoint.checkpoint_manager import CheckpointManager
from internlm.core.context import global_context as gpc
from internlm.core.context.process_group_initializer import ParallelMode
from internlm.core.parallel.comm import initialize_offload_manager
from internlm.core.trainer import Trainer
from internlm.data.streaming.utils import streaming_simple_resume
from internlm.data.train_state import get_train_state
@@ -118,6 +119,9 @@ def __init__(
# initialize isp communicator
isp_communicator = initialize_parallel_communicator(model)

# initialize cpu offload manager for selective checkpoint
initialize_offload_manager(gpc.config.get("selective_checkpoint_offload", False))

# initialize train state
train_state = get_train_state(train_dl)

37 changes: 30 additions & 7 deletions internlm/initialize/launch.py
@@ -66,13 +66,22 @@ def get_default_parser():
def args_sanity_check():
assert gpc.config is not None, "config is not load!"

gpc.is_forward = True

if "JOB_NAME" not in gpc.config:
gpc.config._add_item("JOB_NAME", "AnonymousJob")

# the default model type is INTERNLM
if "model_type" not in gpc.config:
gpc.config._add_item("model_type", ModelType.INTERNLM.name)

if gpc.config.model_type == "InternLM3_M":
# TODO: need check for isp overlap
num_layers = gpc.config.model.num_self_decoder_layers + gpc.config.model.num_cross_decoder_layers
else:
num_layers = gpc.config.model.num_layers
gpc.config.isp_num_layers = num_layers

if "use_apex_adam" not in gpc.config:
gpc.config._add_item("use_apex_adam", False)

@@ -399,17 +408,18 @@ def args_sanity_check():
gpc.config.parallel["tensor"] = dict(size=gpc.config.parallel["tensor"], mode=TensorParallelMode.mtp.name)
if gpc.config.parallel["tensor"].get("mode", None) is None:
gpc.config.parallel["tensor"]["mode"] = TensorParallelMode.mtp.name
assert (
gpc.config.VOCAB_SIZE % gpc.config.parallel.tensor.size == 0
), "VOCAB_SIZE must be integer multiple of tensor parallel size"
if gpc.config.parallel["tensor"]["mode"] == TensorParallelMode.isp.name:
assert not gpc.config.parallel.zero1.fsdp, "FSDP does not support isp"
assert (
torch.__version__ >= "2.1.0"
), f"requires torch>=2.1.0 when using isp but current version is {torch.__version__}"
assert (
gpc.config.VOCAB_SIZE % gpc.config.parallel.weight.size == 0
), "VOCAB_SIZE must be integer multiple of wp size"

assert (
gpc.config.model.vocab_size % gpc.config.parallel.weight.size == 0
), "model.vocab_size must be integer multiple of weight parallel size"
assert (
gpc.config.model.vocab_size % gpc.config.parallel.tensor.size == 0
), "model.vocab_size must be integer multiple of tensor parallel size"

assert gpc.config.parallel["tensor"].get("mode", None) in [
TensorParallelMode.mtp.name,
@@ -532,7 +542,20 @@ def args_sanity_check():
gpc.config.loss._add_item("moe_loss_coeff", 1.0)

if "selective_checkpoint" not in gpc.config:
gpc.config._add_item("selective_checkpoint", False)
gpc.config.selective_checkpoint = False
if "selective_checkpoint_offload" not in gpc.config:
gpc.config.selective_checkpoint_offload = False
if gpc.config.selective_checkpoint is True:
assert (
gpc.config.parallel["tensor"]["mode"] == "isp"
), "When using selective_checkpoint, tensor parallel mode must be isp"
if gpc.config.selective_checkpoint_offload is True:
assert (
gpc.config.selective_checkpoint is True
), "When using selective_checkpoint_offload, selective_checkpoint must be True"
assert (
gpc.config.parallel.weight.launch_allgather_before == "wo"
), "When using selective_checkpoint_offload, wp launch allgather communication should be set before 'wo' module"

# moe not support overlap and zero1.5 for now
if gpc.config.model.get("num_experts", 1) > 1:
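For reference, an illustrative config excerpt that satisfies the checks above (the field names come from this diff; the sizes and other values are placeholders, not a recommended configuration):

# illustrative training-config excerpt, placeholder values only
selective_checkpoint = True            # requires tensor parallel mode "isp"
selective_checkpoint_offload = True    # requires selective_checkpoint = True
model = dict(
    checkpoint=0.5,       # fraction of layers under activation checkpointing
    num_layers=32,
    vocab_size=103168,
)
parallel = dict(
    zero1=dict(size=-1, fsdp=False),
    tensor=dict(size=2, mode="isp"),
    weight=dict(size=4, overlap=True, launch_allgather_before="wo"),
    pipeline=dict(size=1),
)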