@@ -39,9 +39,9 @@
 from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector
 from omegaconf import DictConfig, OmegaConf, open_dict
 
-from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from nemo.collections.common.callbacks import EMA
 from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION
+from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from nemo.utils import logging, timers
 from nemo.utils.app_state import AppState
 from nemo.utils.callbacks import NeMoModelCheckpoint, PreemptionCallback
@@ -268,6 +268,7 @@ class ExpManagerConfig:
 
 try:
     from nemo.lightning.io.pl import TrainerContext
+
     HAVE_TRAINER_CONTEXT = True
 except ImportError:
     logging.warning("[Callback] Cannot import TrainerContext. Will not save extra context information.")
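The added blank line here is cosmetic, but the surrounding try/except is what makes context saving optional: if `TrainerContext` cannot be imported, the flag stays false and the callback downgrades to plain `.ckpt` saving. A minimal, self-contained sketch of the same guard pattern (the flag name comes from the diff; the `False` branch and helper are illustrative, since the hunk only shows the warning):

```python
# Optional-dependency guard: availability is resolved once at import
# time, and consumers check the flag instead of retrying the import.
try:
    from nemo.lightning.io.pl import TrainerContext

    HAVE_TRAINER_CONTEXT = True
except ImportError:
    HAVE_TRAINER_CONTEXT = False  # assumed fallback; diff shows only the warning


def resolve_save_context(requested: bool) -> bool:
    # Mirrors `save_context and HAVE_TRAINER_CONTEXT` in the callback's
    # __init__: the feature is on only if requested AND importable.
    return requested and HAVE_TRAINER_CONTEXT
```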
@@ -280,7 +281,10 @@ class SaveAtStepsCallback(Callback):
     into a subdirectory at specific global step milestones.
     Ensures saving only happens on global rank 0.
     """
-    def __init__(self, save_steps: list[int], save_path: str, filename_prefix: str = "model_step", save_context: bool = True):
+
+    def __init__(
+        self, save_steps: list[int], save_path: str, filename_prefix: str = "model_step", save_context: bool = True
+    ):
         """
         Args:
             save_steps (list[int]): A list or tuple of global steps at which to save checkpoints.
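The reformatted signature is semantically unchanged. For readers wiring this up, a hedged usage sketch based only on that signature; the path and step values are illustrative, and `Trainer(callbacks=...)` is standard PyTorch Lightning API rather than part of this diff:

```python
import lightning.pytorch as pl

# Save full checkpoints at three fixed global-step milestones; context
# dumps stay enabled only if TrainerContext imported successfully.
milestone_cb = SaveAtStepsCallback(
    save_steps=[1000, 5000, 10000],
    save_path="/results/checkpoints/milestones",  # illustrative path
    filename_prefix="model_step",                 # -> model_step_1000.ckpt
    save_context=True,
)
trainer = pl.Trainer(max_steps=10_000, callbacks=[milestone_cb])
```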
@@ -297,7 +301,7 @@ def __init__(self, save_steps: list[int], save_path: str, filename_prefix: str =
         self.save_path = Path(save_path)
         self.filename_prefix = filename_prefix
         self._saved_steps_in_run = set()
-        self.save_context = save_context and HAVE_TRAINER_CONTEXT # Control via init
+        self.save_context = save_context and HAVE_TRAINER_CONTEXT  # Control via init
 
         # Create the save directory if it doesn't exist (only on rank 0)
         if is_global_rank_zero():
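The rank-0 directory creation is only partially visible in this hunk. A plausible completion of that branch, assuming pathlib's `mkdir`; the helper name and the `is_global_rank_zero` import path are assumptions, not confirmed by the diff:

```python
from pathlib import Path

from nemo.utils.get_rank import is_global_rank_zero  # assumed import path


def ensure_save_dir(save_path: Path) -> None:
    # Hypothetical helper mirroring the tail of __init__: only global
    # rank 0 creates the directory, so distributed workers don't race.
    if is_global_rank_zero():
        save_path.mkdir(parents=True, exist_ok=True)
```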
@@ -320,12 +324,12 @@ def on_train_batch_end(
 
         if global_step in self.save_steps and global_step not in self._saved_steps_in_run:
             filename = f"{self.filename_prefix}_{global_step}.ckpt"
-            filepath = self.save_path / filename # Use pathlib join
+            filepath = self.save_path / filename  # Use pathlib join
 
             logging.info(f"[Callback] Saving checkpoint at global step {global_step} to {str(filepath)}")
             try:
                 # 1. Save the .ckpt file using trainer
-                trainer.save_checkpoint(filepath) # Pass Path object directly
+                trainer.save_checkpoint(filepath)  # Pass Path object directly
 
                 # 2. Optionally save the TrainerContext into its subdirectory
                 if self.save_context:
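The hunk cuts off right after `if self.save_context:`. A plausible continuation under stated assumptions: `ckpt_to_context_subdir` and `TrainerContext` are the names imported earlier in this diff, but `from_trainer()`/`io_dump()` and the dedup bookkeeping are guesses at the surrounding code, not confirmed by it:

```python
                # Assumed continuation of the save branch (sketch only):
                if self.save_context:
                    context_dir = ckpt_to_context_subdir(filepath)  # e.g. <ckpt>/context
                    TrainerContext.from_trainer(trainer).io_dump(context_dir)
                    logging.info(f"[Callback] Saved TrainerContext to {str(context_dir)}")

                # Record the step so repeated hooks at the same global step
                # don't save twice; mirrors the _saved_steps_in_run check above.
                self._saved_steps_in_run.add(global_step)
```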
@@ -456,8 +460,6 @@ def on_after_backward(self, trainer, pl_module):
         self._on_batch_end("train_backward_timing", pl_module)
 
 
-
-
 class DeltaTimingCallback(Callback):
     """
     Logs execution time of train/val/test steps using nemo logger. Calculates
0 commit comments