Commit 5146da4

Saves checkpoints at specified steps
1 parent cc8ff45 commit 5146da4

File tree

1 file changed: +82 -1


nemo/utils/exp_manager.py

+82 -1
@@ -39,6 +39,7 @@
 from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector
 from omegaconf import DictConfig, OmegaConf, open_dict

+from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from nemo.collections.common.callbacks import EMA
 from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION
 from nemo.utils import logging, timers
@@ -248,7 +249,7 @@ class ExpManagerConfig:
     # Configures creation of log files for different ranks
     log_local_rank_0_only: Optional[bool] = False
     log_global_rank_0_only: Optional[bool] = False
-    # disable initial validation when resuming from a checkpoint saved during validation
+    # disable initial validation when resuming from a checkpoint
     disable_validation_on_resume: Optional[bool] = True
     ema: Optional[EMAParams] = field(default_factory=lambda: EMAParams())
     # Wall clock time limit
@@ -265,6 +266,84 @@
     log_tflops_per_sec_per_gpu: Optional[bool] = True


+try:
+    from nemo.lightning.io.pl import TrainerContext
+    HAVE_TRAINER_CONTEXT = True
+except ImportError:
+    logging.warning("[Callback] Cannot import TrainerContext. Will not save extra context information.")
+    HAVE_TRAINER_CONTEXT = False
+
+
+class SaveAtStepsCallback(Callback):
+    """
+    Callback to save PTL checkpoints (.ckpt) and associated TrainerContext files
+    into a subdirectory at specific global step milestones.
+    Ensures saving only happens on global rank 0.
+    """
+    def __init__(self, save_steps: list[int], save_path: str, filename_prefix: str = "model_step", save_context: bool = True):
+        """
+        Args:
+            save_steps (list[int]): A list or tuple of global steps at which to save checkpoints.
+            save_path (str): The base directory where checkpoints and context subdirs will be saved.
+                This path should exist or be creatable by rank 0.
+            filename_prefix (str): Prefix for the checkpoint filename.
+                The final name will be f"{filename_prefix}_{global_step}.ckpt".
+            save_context (bool): Whether to also save the TrainerContext alongside the .ckpt file.
+                Defaults to True. Requires TrainerContext to be importable.
+        """
+        super().__init__()
+        self.save_steps = set(save_steps)
+        # Ensure save_path is a Path object for consistency
+        self.save_path = Path(save_path)
+        self.filename_prefix = filename_prefix
+        self._saved_steps_in_run = set()
+        self.save_context = save_context and HAVE_TRAINER_CONTEXT  # Control via init
+
+        # Create the save directory if it doesn't exist (only on rank 0)
+        if is_global_rank_zero():
+            # Use pathlib's mkdir
+            self.save_path.mkdir(parents=True, exist_ok=True)
+
+        # Ensure all ranks wait for rank 0 to potentially create the directory
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
+
+    def on_train_batch_end(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs, batch, batch_idx: int
+    ) -> None:
+        """Check if the current global step is one of the specified steps to save."""
+
+        if not is_global_rank_zero():
+            return
+
+        global_step = trainer.global_step
+
+        if global_step in self.save_steps and global_step not in self._saved_steps_in_run:
+            filename = f"{self.filename_prefix}_{global_step}.ckpt"
+            filepath = self.save_path / filename  # Use pathlib join
+
+            logging.info(f"[Callback] Saving checkpoint at global step {global_step} to {str(filepath)}")
+            try:
+                # 1. Save the .ckpt file using trainer
+                trainer.save_checkpoint(filepath)  # Pass Path object directly
+
+                # 2. Optionally save the TrainerContext into its subdirectory
+                if self.save_context:
+                    # Use ckpt_to_context_subdir to get the target directory path
+                    context_target_dir = ckpt_to_context_subdir(filepath)
+                    logging.info(f"[Callback] Saving TrainerContext to {str(context_target_dir)}")
+                    # Pass the calculated subdirectory path directly to io_dump
+                    TrainerContext.from_trainer(trainer).io_dump(context_target_dir, yaml_attrs=["model"])
+                    logging.info(f"[Callback] Successfully saved TrainerContext in: {str(context_target_dir)}")
+
+                # Mark step as saved only if all parts succeed
+                self._saved_steps_in_run.add(global_step)
+                logging.info(f"[Callback] Successfully completed saving for step: {global_step}")
+
+            except Exception as e:
+                logging.error(f"[Callback] Failed to save checkpoint or context at step {global_step}: {e}")
+
+
 class TimingCallback(Callback):
     """
     Logs execution time of train/val/test steps
@@ -377,6 +456,8 @@ def on_after_backward(self, trainer, pl_module):
         self._on_batch_end("train_backward_timing", pl_module)


+
+
 class DeltaTimingCallback(Callback):
     """
     Logs execution time of train/val/test steps using nemo logger. Calculates
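
Usage note (not part of this commit): a minimal sketch of how the new SaveAtStepsCallback might be attached to a Lightning Trainer once this change is in place. The step values, save path, and the model/datamodule objects below are placeholders; the .ckpt files are written directly under save_path, and when save_context=True the TrainerContext goes into the subdirectory derived by ckpt_to_context_subdir.

import lightning.pytorch as pl

from nemo.utils.exp_manager import SaveAtStepsCallback

# Save full .ckpt checkpoints (plus TrainerContext, if importable) at these global steps.
save_cb = SaveAtStepsCallback(
    save_steps=[1000, 5000, 10000],          # placeholder milestones
    save_path="/results/my_run/step_ckpts",  # placeholder directory, created by rank 0
    filename_prefix="model_step",            # files named model_step_<global_step>.ckpt
    save_context=True,
)

trainer = pl.Trainer(max_steps=10000, callbacks=[save_cb])
trainer.fit(model, datamodule=datamodule)  # model and datamodule are defined elsewhere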
