@@ -39,6 +39,7 @@
 from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector
 from omegaconf import DictConfig, OmegaConf, open_dict
 
+from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from nemo.collections.common.callbacks import EMA
 from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION
 from nemo.utils import logging, timers
@@ -248,7 +249,7 @@ class ExpManagerConfig:
     # Configures creation of log files for different ranks
     log_local_rank_0_only: Optional[bool] = False
     log_global_rank_0_only: Optional[bool] = False
-    # disable initial validation when resuming from a checkpoint saved during validation
+    # disable initial validation when resuming from a checkpoint
     disable_validation_on_resume: Optional[bool] = True
     ema: Optional[EMAParams] = field(default_factory=lambda: EMAParams())
     # Wall clock time limit
@@ -265,6 +266,84 @@ class ExpManagerConfig:
     log_tflops_per_sec_per_gpu: Optional[bool] = True
 
 
+try:
+    from nemo.lightning.io.pl import TrainerContext
+    HAVE_TRAINER_CONTEXT = True
+except ImportError:
+    logging.warning("[Callback] Cannot import TrainerContext. Will not save extra context information.")
+    HAVE_TRAINER_CONTEXT = False
+
+
+class SaveAtStepsCallback(Callback):
+    """
+    Callback to save PTL checkpoints (.ckpt) and associated TrainerContext files
+    into a subdirectory at specific global step milestones.
+    Ensures saving only happens on global rank 0.
+    """
+    def __init__(self, save_steps: list[int], save_path: str, filename_prefix: str = "model_step", save_context: bool = True):
+        """
+        Args:
+            save_steps (list[int]): A list or tuple of global steps at which to save checkpoints.
+            save_path (str): The base directory where checkpoints and context subdirs will be saved.
+                This path should exist or be creatable by rank 0.
+            filename_prefix (str): Prefix for the checkpoint filename.
+                The final name will be f"{filename_prefix}_{global_step}.ckpt".
+            save_context (bool): Whether to also save the TrainerContext alongside the .ckpt file.
+                Defaults to True. Requires TrainerContext to be importable.
+        """
+        super().__init__()
+        self.save_steps = set(save_steps)
+        # Ensure save_path is a Path object for consistency
+        self.save_path = Path(save_path)
+        self.filename_prefix = filename_prefix
+        self._saved_steps_in_run = set()
+        self.save_context = save_context and HAVE_TRAINER_CONTEXT  # Control via init
+
+        # Create the save directory if it doesn't exist (only on rank 0)
+        if is_global_rank_zero():
+            # Use pathlib's mkdir
+            self.save_path.mkdir(parents=True, exist_ok=True)
+
+        # Ensure all ranks wait for rank 0 to potentially create the directory
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
+
+    def on_train_batch_end(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs, batch, batch_idx: int
+    ) -> None:
+        """Check if the current global step is one of the specified steps to save."""
+
+        if not is_global_rank_zero():
+            return
+
+        global_step = trainer.global_step
+
+        if global_step in self.save_steps and global_step not in self._saved_steps_in_run:
+            filename = f"{self.filename_prefix}_{global_step}.ckpt"
+            filepath = self.save_path / filename  # Use pathlib join
+
+            logging.info(f"[Callback] Saving checkpoint at global step {global_step} to {str(filepath)}")
+            try:
+                # 1. Save the .ckpt file using trainer
+                trainer.save_checkpoint(filepath)  # Pass Path object directly
+
+                # 2. Optionally save the TrainerContext into its subdirectory
+                if self.save_context:
+                    # Use ckpt_to_context_subdir to get the target directory path
+                    context_target_dir = ckpt_to_context_subdir(filepath)
+                    logging.info(f"[Callback] Saving TrainerContext to {str(context_target_dir)}")
+                    # Pass the calculated subdirectory path directly to io_dump
+                    TrainerContext.from_trainer(trainer).io_dump(context_target_dir, yaml_attrs=["model"])
+                    logging.info(f"[Callback] Successfully saved TrainerContext in: {str(context_target_dir)}")
+
+                # Mark step as saved only if all parts succeed
+                self._saved_steps_in_run.add(global_step)
+                logging.info(f"[Callback] Successfully completed saving for step: {global_step}")
+
+            except Exception as e:
+                logging.error(f"[Callback] Failed to save checkpoint or context at step {global_step}: {e}")
+
+
 class TimingCallback(Callback):
     """
     Logs execution time of train/val/test steps
@@ -377,6 +456,8 @@ def on_after_backward(self, trainer, pl_module):
         self._on_batch_end("train_backward_timing", pl_module)
 
 
+
+
 class DeltaTimingCallback(Callback):
     """
     Logs execution time of train/val/test steps using nemo logger. Calculates
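For reference, a minimal usage sketch of the new callback (not part of this diff). It assumes `SaveAtStepsCallback` is importable from `nemo.utils.exp_manager`, the module this diff modifies, and that it is attached through the Trainer's `callbacks` argument; the save path and step values below are placeholders.

```python
import lightning.pytorch as pl

from nemo.utils.exp_manager import SaveAtStepsCallback  # assumed export location

# Snapshot the model at a few fixed global steps, in addition to any
# regular checkpointing configured elsewhere.
save_cb = SaveAtStepsCallback(
    save_steps=[100, 500, 1000],            # global steps to capture
    save_path="/results/step_checkpoints",  # created by rank 0 if missing
    filename_prefix="model_step",           # -> model_step_100.ckpt, ...
    save_context=True,                      # also dump TrainerContext next to each .ckpt
)

trainer = pl.Trainer(max_steps=1000, callbacks=[save_cb])
# trainer.fit(model, datamodule) would then write
# /results/step_checkpoints/model_step_<step>.ckpt (plus a context subdirectory
# when TrainerContext is available) at each requested step on global rank 0.
```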