3
3
4
4
from dataclasses import dataclass
5
5
from functools import partial
6
- from typing import Any , Dict , List , Union
6
+ from typing import Any , Callable , Dict , List , Union
7
7
8
8
import torch
9
9
from torch import distributed as dist
@@ -135,7 +135,8 @@ def __init__(self) -> None:
135
135
self .last_ckpt_block : nn .Module = None
136
136
self .isp_outs : List [nn .Module ] = []
137
137
self .isp_modules : List [nn .Module ] = []
138
- self .index_to_isp_module : Dict [int , nn .Module ] = {}
138
+ self .index_to_isp_modules : Dict [int , nn .Module ] = {}
139
+ self .index_to_block : Dict [int , nn .Module ] = {}
139
140
self .module_to_index : Dict [nn .Module , int ] = {}
140
141
self .weight_global_handle : Dict [str , Any ] = {}
141
142
self .weight_global_output : Dict [str , torch .Tensor ] = {}
@@ -163,6 +164,7 @@ def __init__(
163
164
self .is_forward = True
164
165
self .reduce_scatter_handlers = {}
165
166
self ._module_shapes = {}
167
+ self ._forward_prefetch_prerequisites = []
166
168
167
169
# real overlap state for each chunk.
168
170
self ._overlap_states : Dict [int , ISPOverlapState ] = {}
@@ -186,7 +188,9 @@ def __init__(
186
188
# key: isp module; value: transformer block index
187
189
self ._module_to_index = None
188
190
# key: transformer block index; value: isp modules
189
- self ._index_to_isp_module = None
191
+ self ._index_to_isp_modules = None
192
+ # key: transformer block index; value: transformer block
193
+ self ._index_to_block = None
190
194
191
195
# init overlap states if necessary.
192
196
if self .overlap :
@@ -228,7 +232,8 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
228
232
]
229
233
230
234
for idx , block in enumerate (children ):
231
- self ._overlap_states [cid ].index_to_isp_module [idx ] = []
235
+ self ._overlap_states [cid ].index_to_isp_modules [idx ] = []
236
+ self ._overlap_states [cid ].index_to_block [idx ] = block
232
237
for sub_name , sub in block .named_children ():
233
238
for name , child in sub .named_children ():
234
239
if name in ["out_proj" , "wo" ]:
@@ -243,7 +248,7 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
243
248
self ._module_shapes [name ] = torch .Size (origin_shape )
244
249
self ._overlap_states [cid ].module_to_index [child ] = idx
245
250
self ._overlap_states [cid ].isp_modules .append (child )
246
- self ._overlap_states [cid ].index_to_isp_module [idx ].append (child )
251
+ self ._overlap_states [cid ].index_to_isp_modules [idx ].append (child )
247
252
248
253
setattr (child , "isp_name" , name )
249
254
@@ -260,7 +265,7 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None:
260
265
f"{ full_name } .bias" ,
261
266
)
262
267
263
- self ._overlap_states [cid ].num_blocks = len (self ._overlap_states [cid ].index_to_isp_module )
268
+ self ._overlap_states [cid ].num_blocks = len (self ._overlap_states [cid ].index_to_isp_modules )
264
269
265
270
def _all_gather_module_weight (self , module ):
266
271
with_bias = module .bias is not None
@@ -307,7 +312,15 @@ def _all_gather_module_weight(self, module):
307
312
self ._weight_global_output [module ] = weight_output
308
313
309
314
def _all_gather_block_weight(self, block_index: int):
    """Prefetch (all-gather) the weights of every ISP module in one transformer block.

    Args:
        block_index (int): index of the transformer block whose module
            weights should be gathered.

    During the forward pass, every registered prerequisite callback is first
    invoked with the block object, so external components can delay the
    prefetch until their own conditions are met. Callback return values are
    ignored here. Backward-pass prefetches skip the callbacks entirely.
    """
    block = self._index_to_block[block_index]

    # wait for prerequisite conditions
    if self.is_forward:
        for callback in self._forward_prefetch_prerequisites:
            callback(block)

    # prefetch parameters for all isp modules of the block
    for module in self._index_to_isp_modules[block_index]:
        self._all_gather_module_weight(module)
312
325
313
326
def _wait_handle (self , module ):
@@ -358,7 +371,7 @@ def _pre_forward_hook_for_module(self, module: nn.Module, *args): # pylint: dis
358
371
self ._wait_handle (module )
359
372
360
373
def _pre_forward_hook_for_block(self, *args):  # pylint: disable=W0613
    """Gather the weights of all ISP modules belonging to the last
    activation-checkpointed block before the block's forward runs."""
    last_ckpt_index = self._ckpt_block_num - 1
    for isp_module in self._index_to_isp_modules[last_ckpt_index]:
        self._all_gather_module_weight(isp_module)
363
376
364
377
def _post_forward_hook_for_module (self , module : nn .Module , * args ): # pylint: disable=W0613
@@ -446,13 +459,41 @@ def switch_current_model_chunk(self, chunk_id: int) -> None:
446
459
self ._weight_global_output = self ._overlap_states [chunk_id ].weight_global_output
447
460
self ._bias_global_output = self ._overlap_states [chunk_id ].bias_global_output
448
461
self ._module_to_index = self ._overlap_states [chunk_id ].module_to_index
449
- self ._index_to_isp_module = self ._overlap_states [chunk_id ].index_to_isp_module
462
+ self ._index_to_isp_modules = self ._overlap_states [chunk_id ].index_to_isp_modules
463
+ self ._index_to_block = self ._overlap_states [chunk_id ].index_to_block
450
464
self ._ckpt_block_num = self ._overlap_states [chunk_id ].ckpt_block_num
451
465
self ._last_ckpt_block = self ._overlap_states [chunk_id ].last_ckpt_block
452
466
self ._head = self ._overlap_states [chunk_id ].head
453
467
self ._embedding = self ._overlap_states [chunk_id ].embedding
454
468
self ._num_blocks = self ._overlap_states [chunk_id ].num_blocks
455
469
470
def register_prerequisite_for_forward_prefetch_hooks(self, prerequisite_func: Callable) -> None:
    """
    Registers a callback function that specifies a prerequisite condition for
    prefetching parameters before forward computation.

    This method allows users to define custom logic that must be satisfied before
    parameters are fetched for the next forward pass. This can be useful for
    implementing complex parameter update strategies or for coordinating
    parameter access with other system components.

    Args:
        prerequisite_func (Callable): A callable that represents the prerequisite
                                      condition. It is invoked with the transformer
                                      block about to be prefetched, immediately
                                      before the block's weights are gathered
                                      (forward pass only). It is expected to block
                                      until its condition holds; its return value
                                      is ignored by the prefetch path.

    Returns:
        None: This method does not return any value.

    Raises:
        TypeError: If the provided 'prerequisite_func' is not callable.
    """
    if not callable(prerequisite_func):
        raise TypeError("The provided prerequisite function must be callable.")

    self._forward_prefetch_prerequisites.append(prerequisite_func)
496
+
456
497
# communication operation interfaces
457
498
458
499
def all_gather (self , tensor : torch .Tensor , module : nn .Module , is_bias : bool = False ):
@@ -521,8 +562,7 @@ def __init__(self, overlap_handler: ISPCommunicator, zero_optim) -> None:
521
562
self ._zero_optim = zero_optim
522
563
523
564
def before_forward(self, scheduler, inputs) -> None:
    """Put the ISP communicator into forward mode and activate the model
    chunk that the upcoming forward computation belongs to."""
    communicator = self._isp_communicator
    communicator.is_forward = True

    # Pick the active virtual-pipeline chunk; fall back to chunk 0 when
    # virtual pipeline parallelism is not in use.
    vpp_rank = gpc.virtual_pipeline_parallel_rank
    communicator.switch_current_model_chunk(0 if vpp_rank is None else vpp_rank)
@@ -537,8 +577,7 @@ def after_criterion(self, scheduler, loss) -> None:
537
577
pass
538
578
539
579
def before_backward(self, scheduler, outputs, outputs_grad) -> None:
    """Put the ISP communicator into backward mode and activate the model
    chunk that the upcoming backward computation belongs to."""
    communicator = self._isp_communicator
    communicator.is_forward = False

    # Pick the active virtual-pipeline chunk; fall back to chunk 0 when
    # virtual pipeline parallelism is not in use.
    vpp_rank = gpc.virtual_pipeline_parallel_rank
    communicator.switch_current_model_chunk(0 if vpp_rank is None else vpp_rank)
0 commit comments