Commit 086fd3e

Fix/fix broadcast overlap with isp (#64)
1 parent ce31407 commit 086fd3e

6 files changed: +191 -132 lines changed

internlm/core/communication/isp.py

Lines changed: 52 additions & 13 deletions
@@ -3,7 +3,7 @@

 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Dict, List, Union
+from typing import Any, Callable, Dict, List, Union

 import torch
 from torch import distributed as dist
@@ -135,7 +135,8 @@ def __init__(self) -> None:
         self.last_ckpt_block: nn.Module = None
         self.isp_outs: List[nn.Module] = []
         self.isp_modules: List[nn.Module] = []
-        self.index_to_isp_module: Dict[int, nn.Module] = {}
+        self.index_to_isp_modules: Dict[int, nn.Module] = {}
+        self.index_to_block: Dict[int, nn.Module] = {}
         self.module_to_index: Dict[nn.Module, int] = {}
         self.weight_global_handle: Dict[str, Any] = {}
         self.weight_global_output: Dict[str, torch.Tensor] = {}
@@ -163,6 +164,7 @@ def __init__(
         self.is_forward = True
         self.reduce_scatter_handlers = {}
         self._module_shapes = {}
+        self._forward_prefetch_prerequisites = []

         # real overlap state for each chunk.
         self._overlap_states: Dict[int, ISPOverlapState] = {}
@@ -186,7 +188,9 @@ def __init__(
         # key: isp module; value: transformer block index
         self._module_to_index = None
         # key: transformer block index; value: isp modules
-        self._index_to_isp_module = None
+        self._index_to_isp_modules = None
+        # key: transformer block index; value: transformer block
+        self._index_to_block = None

         # init overlap states if necessary.
         if self.overlap:
@@ -228,7 +232,8 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
                 ]

                 for idx, block in enumerate(children):
-                    self._overlap_states[cid].index_to_isp_module[idx] = []
+                    self._overlap_states[cid].index_to_isp_modules[idx] = []
+                    self._overlap_states[cid].index_to_block[idx] = block
                     for sub_name, sub in block.named_children():
                         for name, child in sub.named_children():
                             if name in ["out_proj", "wo"]:
@@ -243,7 +248,7 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
                                     self._module_shapes[name] = torch.Size(origin_shape)
                                 self._overlap_states[cid].module_to_index[child] = idx
                                 self._overlap_states[cid].isp_modules.append(child)
-                                self._overlap_states[cid].index_to_isp_module[idx].append(child)
+                                self._overlap_states[cid].index_to_isp_modules[idx].append(child)

                                 setattr(child, "isp_name", name)

@@ -260,7 +265,7 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
                         f"{full_name}.bias",
                     )

-        self._overlap_states[cid].num_blocks = len(self._overlap_states[cid].index_to_isp_module)
+        self._overlap_states[cid].num_blocks = len(self._overlap_states[cid].index_to_isp_modules)

     def _all_gather_module_weight(self, module):
         with_bias = module.bias is not None
@@ -307,7 +312,15 @@ def _all_gather_module_weight(self, module):
         self._weight_global_output[module] = weight_output

     def _all_gather_block_weight(self, block_index: int):
-        for module in self._index_to_isp_module[block_index]:
+        block = self._index_to_block[block_index]
+
+        # wait for prerequisite conditions
+        if self.is_forward:
+            for callback in self._forward_prefetch_prerequisites:
+                callback(block)
+
+        # prefetch parameters for all isp modules of the block
+        for module in self._index_to_isp_modules[block_index]:
             self._all_gather_module_weight(module)

     def _wait_handle(self, module):
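
To illustrate the new prefetch path in _all_gather_block_weight above, here is a minimal standalone sketch. FakeCommunicator, the block/module names, and the print calls are invented stand-ins rather than code from this commit; the real ISPCommunicator all-gathers module weights over a process group instead of printing.

# Minimal standalone sketch of the new prefetch ordering, assuming invented names.
class FakeCommunicator:
    def __init__(self):
        self.is_forward = True
        self._forward_prefetch_prerequisites = []
        self._index_to_block = {0: "block_0"}
        self._index_to_isp_modules = {0: ["wqkv", "out_proj"]}

    def _all_gather_module_weight(self, module):
        print(f"all-gather weight of {module}")

    def _all_gather_block_weight(self, block_index):
        block = self._index_to_block[block_index]
        # prerequisite callbacks run first, and only on the forward pass
        if self.is_forward:
            for callback in self._forward_prefetch_prerequisites:
                callback(block)
        # then every isp module of the block is prefetched
        for module in self._index_to_isp_modules[block_index]:
            self._all_gather_module_weight(module)

comm = FakeCommunicator()
comm._forward_prefetch_prerequisites.append(lambda blk: print(f"prerequisite satisfied for {blk}"))
comm._all_gather_block_weight(0)
# prints the prerequisite message first, then the two all-gather lines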
@@ -358,7 +371,7 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: dis
         self._wait_handle(module)

     def _pre_forward_hook_for_block(self, *args):  # pylint: disable=W0613
-        for module in self._index_to_isp_module[self._ckpt_block_num - 1]:
+        for module in self._index_to_isp_modules[self._ckpt_block_num - 1]:
             self._all_gather_module_weight(module)

     def _post_forward_hook_for_module(self, module: nn.Module, *args):  # pylint: disable=W0613
@@ -446,13 +459,41 @@ def switch_current_model_chunk(self, chunk_id: int) -> None:
         self._weight_global_output = self._overlap_states[chunk_id].weight_global_output
         self._bias_global_output = self._overlap_states[chunk_id].bias_global_output
         self._module_to_index = self._overlap_states[chunk_id].module_to_index
-        self._index_to_isp_module = self._overlap_states[chunk_id].index_to_isp_module
+        self._index_to_isp_modules = self._overlap_states[chunk_id].index_to_isp_modules
+        self._index_to_block = self._overlap_states[chunk_id].index_to_block
         self._ckpt_block_num = self._overlap_states[chunk_id].ckpt_block_num
         self._last_ckpt_block = self._overlap_states[chunk_id].last_ckpt_block
         self._head = self._overlap_states[chunk_id].head
         self._embedding = self._overlap_states[chunk_id].embedding
         self._num_blocks = self._overlap_states[chunk_id].num_blocks

+    def register_prerequisite_for_forward_prefetch_hooks(self, prerequisite_func: Callable) -> None:
+        """
+        Registers a callback function that specifies a prerequisite condition for
+        prefetching parameters before forward computation.
+
+        This method allows users to define custom logic that must be satisfied before
+        parameters are fetched for the next forward pass. This can be useful for
+        implementing complex parameter update strategies or for coordinating
+        parameter access with other system components.
+
+        Args:
+            prerequisite_func (Callable): A callable that represents the prerequisite
+                                          condition. It is invoked with the transformer
+                                          block whose parameters are about to be prefetched,
+                                          and should return only once that condition holds.
+
+        Returns:
+            None: This method does not return any value.
+
+        Raises:
+            TypeError: If the provided 'prerequisite_func' is not callable.
+        """
+        if not callable(prerequisite_func):
+            raise TypeError("The provided prerequisite function must be callable.")
+
+        self._forward_prefetch_prerequisites.append(prerequisite_func)
+
     # communication operation interfaces

     def all_gather(self, tensor: torch.Tensor, module: nn.Module, is_bias: bool = False):
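
As a usage note, the sketch below mimics only the registration contract of register_prerequisite_for_forward_prefetch_hooks; the module-level list stands in for the communicator's _forward_prefetch_prerequisites and is not code from this commit.

# Standalone mimic of the registration contract, assuming a module-level stand-in list.
from typing import Callable, List

forward_prefetch_prerequisites: List[Callable] = []

def register_prerequisite_for_forward_prefetch_hooks(prerequisite_func: Callable) -> None:
    if not callable(prerequisite_func):
        raise TypeError("The provided prerequisite function must be callable.")
    forward_prefetch_prerequisites.append(prerequisite_func)

register_prerequisite_for_forward_prefetch_hooks(lambda block: None)  # accepted
try:
    register_prerequisite_for_forward_prefetch_hooks("not callable")  # rejected
except TypeError as exc:
    print(exc)  # The provided prerequisite function must be callable.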
@@ -521,8 +562,7 @@ def __init__(self, overlap_handler: ISPCommunicator, zero_optim) -> None:
         self._zero_optim = zero_optim

     def before_forward(self, scheduler, inputs) -> None:
-        if self._isp_communicator._ckpt_block_num > 0:
-            self._isp_communicator.is_forward = True
+        self._isp_communicator.is_forward = True
         # switch model chunk before forward
         chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank
         self._isp_communicator.switch_current_model_chunk(chunk_id)
@@ -537,8 +577,7 @@ def after_criterion(self, scheduler, loss) -> None:
         pass

     def before_backward(self, scheduler, outputs, outputs_grad) -> None:
-        if self._isp_communicator._ckpt_block_num > 0:
-            self._isp_communicator.is_forward = False
+        self._isp_communicator.is_forward = False
         # switch model chunk before backward
         chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank
         self._isp_communicator.switch_current_model_chunk(chunk_id)

internlm/core/communication/utils.py

Lines changed: 107 additions & 1 deletion
@@ -1,12 +1,19 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication

-from typing import List, Tuple, Union
+from collections import OrderedDict
+from typing import Dict, List, Tuple, Union

 import torch
 import torch.distributed as dist
+from flash_attn.modules.embedding import ParallelGPT2Embeddings
+from torch import nn

+from internlm.core.communication.isp import ISPCommunicator
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.core.naive_amp import NaiveAMPModel
+from internlm.model.embedding import Embedding1D
+from internlm.model.linear import BaseScaleColumnParallelLinear
 from internlm.utils.common import get_current_device

 TensorShape = Union[torch.Size, List[int], Tuple[int]]
@@ -123,3 +130,102 @@ def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
     chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
     dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR))
     return gathered
+
+
+class ParamAsyncBcastHandler:
+    """
+    Model partition handler for overlapping parameter broadcast with forward computation.
+    """
+
+    def __init__(
+        self, zero1_mode: ParallelMode, model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicator = None
+    ) -> None:
+        self._block_to_param: Dict[nn.Module, List[nn.Parameter]] = OrderedDict()
+        self._param_to_rank: Dict[nn.Parameter, int] = {}
+        self._block_to_rank: Dict[nn.Module, List[int]] = {}
+        self._bcast_handles: Dict[int, List[dist.Work]] = {}
+
+        zero1_size = gpc.get_world_size(zero1_mode)
+        total_param_num = sum(p.numel() for p in model.parameters())
+        avg_param_num = total_param_num * 1.0 // zero1_size
+
+        # initialize an empty list for _bcast_handles of each rank
+        self._bcast_handles = {rank: [] for rank in range(zero1_size)}
+
+        # just want to share the same for loop for ModuleList and Module
+        if not isinstance(model, nn.ModuleList):
+            model = [model]
+
+        # map parameters to their transformer/embedding/head/norm block
+        for _chunk in model:
+            if isinstance(_chunk, NaiveAMPModel):
+                _chunk = _chunk.model
+
+            for _, children in _chunk.named_children():
+                # should be the transformer block definition in modeling_xxx.py
+                if isinstance(children, nn.ModuleList):
+                    # record the block that a parameter belongs to
+                    for _, block in enumerate(children):
+                        # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
+                        self._block_to_param[block] = list(block.parameters())
+                else:
+                    # record the block that a parameter belongs to
+                    # self._block_to_param[name] = list(children.parameters())
+                    self._block_to_param[children] = list(children.parameters())
+
+        alloc_num = 0
+        rank_to_go = 0

+        # process the parameters in block_to_param sequentially,
+        # allocating each parameter to a local rank of ParallelMode.ZERO1.
+        # NOTE that we do NOT consider the following scenarios:
+        # 1) whether a parameter is trainable;
+        # 2) parameters may be in different optimizer groups
+        for block, params in self._block_to_param.items():
+            # allocate a model block to a local rank of ParallelMode.ZERO1
+            self._block_to_rank[block] = [rank_to_go]
+            for p in params:
+                alloc_num = alloc_num + p.numel()
+                # in this case, allocate the param to the next rank if possible
+                if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
+                    rank_to_go = rank_to_go + 1
+                    alloc_num = 0
+                    self._block_to_rank[block].append(rank_to_go)
+                # allocate a parameter to a local rank of ParallelMode.ZERO1
+                self._param_to_rank[p] = rank_to_go
+
+        # register_forward_pre_hook for transformer/embedding/norm/xxx block
+        self._register_sync_parameters_hook(isp_communicator)
+
+    def _register_sync_parameters_hook(self, isp_communicator: ISPCommunicator = None) -> None:
+        def _pre_forward_hook(model: nn.Module, *args, **kwargs):  # pylint: disable=W0613
+            bcast_handles = []
+            # gather all required broadcast handles into a list
+            for rank in self._block_to_rank[model]:
+                bcast_handles.extend(self._bcast_handles[rank])
+                # need to clear _bcast_handles since they would be processed later
+                self._bcast_handles[rank] = []
+            # wait for all required broadcast handles to be completed
+            for handle in bcast_handles:
+                handle.wait()
+
+        # register_forward_pre_hook for transformer/embedding/norm/xxx block
+        for block, _ in self._block_to_rank.items():
+            # TODO: remove special handling for embedding and head layers,
+            # instead implement support for weight parallelism of embedding and head layers within the ISP.
+
+            # NOTE: Although the layernorm layer does not have explicit processing,
+            # both ISPCommunicator and ParamAsyncBcastHandler treat transformer blocks as the granularity,
+            # so everything is fine.
+            if isp_communicator is None or isinstance(
+                block, (Embedding1D, ParallelGPT2Embeddings, BaseScaleColumnParallelLinear)
+            ):
+                block.register_forward_pre_hook(_pre_forward_hook)
+            else:
+                isp_communicator.register_prerequisite_for_forward_prefetch_hooks(_pre_forward_hook)
+
+    def get_rank_by_param(self, param) -> int:
+        return self._param_to_rank[param]
+
+    def add_bcast_handle(self, rank, handle) -> None:
+        self._bcast_handles[rank].append(handle)
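
The greedy allocation in __init__ above can be illustrated with a small standalone sketch. The block names and parameter sizes are invented; the real handler walks block.parameters() and uses p.numel().

# Standalone sketch of the greedy parameter-to-rank allocation, with invented sizes.
blocks = {
    "embedding": [120],
    "block0": [60, 60],
    "block1": [60, 60],
    "block2": [60, 60],
    "head": [100],
}
zero1_size = 4
total_param_num = sum(sum(sizes) for sizes in blocks.values())
avg_param_num = total_param_num / zero1_size

block_to_rank = {}
alloc_num, rank_to_go = 0, 0
for block, param_sizes in blocks.items():
    # a block starts on the current rank and may spill over onto the next one
    block_to_rank[block] = [rank_to_go]
    for numel in param_sizes:
        alloc_num += numel
        if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
            rank_to_go += 1
            alloc_num = 0
            block_to_rank[block].append(rank_to_go)

print(block_to_rank)
# {'embedding': [0], 'block0': [0, 1], 'block1': [1, 2], 'block2': [2], 'head': [2, 3]}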

internlm/solver/optimizer/hybrid_zero_optim.py

Lines changed: 25 additions & 21 deletions
@@ -3,12 +3,14 @@

 import math
 from functools import partial
+from itertools import product
 from typing import List, Optional

 import torch
 import torch.distributed as dist
 from torch.optim import Optimizer

+from internlm.core.communication.utils import ParamAsyncBcastHandler
 from internlm.core.context import IS_REPLICA_ZERO_PARALLEL, Config, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.context.parallel_context import (
@@ -26,7 +28,6 @@
 )
 from internlm.solver.optimizer.utils import (
     DynamicGradScaler,
-    ParamBcastSyncHandler,
     flatten,
     get_grad_accumulate_object,
     has_inf_or_nan,
@@ -66,7 +67,7 @@ def __init__(
         cpu_offload=False,
         grad_scal_cfg: Config = None,
         zero_cfg: Config = None,
-        param_bcast_sync_handler: ParamBcastSyncHandler = None,
+        param_bcast_sync_handler: ParamAsyncBcastHandler = None,
         isp_communicator=None,
     ):
         # DynamicGradScaler related args
@@ -1053,26 +1054,29 @@ def _step(self, closure=None, norms=None):
     def broadcast_params(self):
         handles = []

-        for group_id in range(self.num_param_groups):
-            for rank in range(self._zero_world_size[group_id]):
-                # The following operations are performed only on the rank to which parameters are assigned.
-                if rank in self.param_group_no_params_ranks[group_id]:
-                    continue
-                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
-                # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
-                # assert grank == rank, f"{grank} == {rank}"
-                g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]
-                handle = dist.broadcast(
-                    fp16_param,
-                    src=g_rank,
-                    group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
-                    async_op=True,
-                )
+        # traverse by rank first, which helps overlap the broadcast communication.
+        for rank, group_id in product(range(max(self._zero_world_size)), range(self.num_param_groups)):
+            # skip ranks that are not in this parameter group.
+            if rank >= self._zero_world_size[group_id]:
+                continue
+            # The following operations are performed only on the rank to which parameters are assigned.
+            if rank in self.param_group_no_params_ranks[group_id]:
+                continue
+            fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
+            # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
+            # assert grank == rank, f"{grank} == {rank}"
+            g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode[group_id])[rank]
+            handle = dist.broadcast(
+                fp16_param,
+                src=g_rank,
+                group=gpc.get_group(self._broadcast_parallel_mode[group_id]),
+                async_op=True,
+            )

-                if self._overlap_sync_param:
-                    self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
-                else:
-                    handles.append(handle)
+            if self._overlap_sync_param:
+                self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
+            else:
+                handles.append(handle)

         for handle in handles:
             handle.wait()
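
A small sketch of the traversal-order change in broadcast_params, using invented world sizes. The idea appears to be that issuing one rank's broadcasts across all parameter groups back to back lets ParamAsyncBcastHandler collect and wait on that rank's handles while later ranks' broadcasts are still in flight.

# Sketch of old (group-first) vs new (rank-first) iteration order; values are invented.
from itertools import product

zero_world_size = [4, 2]  # per-parameter-group ZeRO-1 world sizes
num_param_groups = len(zero_world_size)

old_order = [
    (rank, group_id)
    for group_id in range(num_param_groups)
    for rank in range(zero_world_size[group_id])
]
new_order = [
    (rank, group_id)
    for rank, group_id in product(range(max(zero_world_size)), range(num_param_groups))
    if rank < zero_world_size[group_id]  # skip ranks not present in this group
]

print(old_order)  # [(0, 0), (1, 0), (2, 0), (3, 0), (0, 1), (1, 1)]
print(new_order)  # [(0, 0), (0, 1), (1, 0), (1, 1), (2, 0), (3, 0)]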
