diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py
index cfb73ca1a861..8825cf957785 100644
--- a/nemo/lightning/resume.py
+++ b/nemo/lightning/resume.py
@@ -104,6 +104,9 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
             raise NotImplementedError("Fabric is not supported yet.")

         trainer_ckpt_path = self.get_trainer_ckpt_path(model)
+        # Need a way to actually restore the context.
+        model = _try_restore_tokenizer(model, os.path.join(trainer_ckpt_path.parent, "context"))
+
         if trainer_ckpt_path:
             trainer.ckpt_path = trainer_ckpt_path
             trainer.checkpoint_callback.last_model_path = trainer_ckpt_path
@@ -270,7 +273,8 @@ def _find_trainer_ckpt_path(self) -> Optional[Path]:
         return checkpoint

     def get_context_path(self, model: Optional[io.ConnectorMixin] = None) -> Optional[Path]:
-        checkpoint = None
+
+        checkpoint = None  # ??? this is totally wrong.
         app_state = AppState()
         app_state.restore = self.resume_if_exists
         if self.resume_if_exists:
diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py
index 2d82ecc9e110..81820f2a069d 100644
--- a/nemo/utils/exp_manager.py
+++ b/nemo/utils/exp_manager.py
@@ -41,6 +41,7 @@

 from nemo.collections.common.callbacks import EMA
 from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION
+from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
 from nemo.utils import logging, timers
 from nemo.utils.app_state import AppState
 from nemo.utils.callbacks import NeMoModelCheckpoint, PreemptionCallback
@@ -248,7 +249,7 @@ class ExpManagerConfig:
     # Configures creation of log files for different ranks
     log_local_rank_0_only: Optional[bool] = False
     log_global_rank_0_only: Optional[bool] = False
-    # disable initial validation when resuming from a checkpoint saved during validation
+    # disable initial validation when resuming from a checkpoint
     disable_validation_on_resume: Optional[bool] = True
     ema: Optional[EMAParams] = field(default_factory=lambda: EMAParams())
     # Wall clock time limit
@@ -265,6 +266,88 @@ class ExpManagerConfig:
     log_tflops_per_sec_per_gpu: Optional[bool] = True


+try:
+    from nemo.lightning.io.pl import TrainerContext
+
+    HAVE_TRAINER_CONTEXT = True
+except ImportError:
+    logging.warning("[Callback] Cannot import TrainerContext. Will not save extra context information.")
+    HAVE_TRAINER_CONTEXT = False
+
+
+class SaveAtStepsCallback(Callback):
+    """
+    Callback to save PTL checkpoints (.ckpt) and associated TrainerContext files
+    into a subdirectory at specific global step milestones.
+    Ensures saving only happens on global rank 0.
+    """
+
+    def __init__(
+        self, save_steps: list[int], save_path: str, filename_prefix: str = "model_step", save_context: bool = True
+    ):
+        """
+        Args:
+            save_steps (list[int]): A list or tuple of global steps at which to save checkpoints.
+            save_path (str): The base directory where checkpoints and context subdirs will be saved.
+                This path should exist or be creatable by rank 0.
+            filename_prefix (str): Prefix for the checkpoint filename.
+                The final name will be f"{filename_prefix}_{global_step}.ckpt".
+            save_context (bool): Whether to also save the TrainerContext alongside the .ckpt file.
+                Defaults to True. Requires TrainerContext to be importable.
+        """
+        super().__init__()
+        self.save_steps = set(save_steps)
+        # Ensure save_path is a Path object for consistency
+        self.save_path = Path(save_path)
+        self.filename_prefix = filename_prefix
+        self._saved_steps_in_run = set()
+        self.save_context = save_context and HAVE_TRAINER_CONTEXT  # Control via init
+
+        # Create the save directory if it doesn't exist (only on rank 0)
+        if is_global_rank_zero():
+            # Use pathlib's mkdir
+            self.save_path.mkdir(parents=True, exist_ok=True)
+
+        # Ensure all ranks wait for rank 0 to potentially create the directory
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
+
+    def on_train_batch_end(
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs, batch, batch_idx: int
+    ) -> None:
+        """Check if the current global step is one of the specified steps to save."""
+
+        if not is_global_rank_zero():
+            return
+
+        global_step = trainer.global_step
+
+        if global_step in self.save_steps and global_step not in self._saved_steps_in_run:
+            filename = f"{self.filename_prefix}_{global_step}.ckpt"
+            filepath = self.save_path / filename  # Use pathlib join
+
+            logging.info(f"[Callback] Saving checkpoint at global step {global_step} to {str(filepath)}")
+            try:
+                # 1. Save the .ckpt file using trainer
+                trainer.save_checkpoint(filepath)  # Pass Path object directly
+
+                # 2. Optionally save the TrainerContext into its subdirectory
+                if self.save_context:
+                    # Use ckpt_to_context_subdir to get the target directory path
+                    context_target_dir = ckpt_to_context_subdir(filepath)
+                    logging.info(f"[Callback] Saving TrainerContext to {str(context_target_dir)}")
+                    # Pass the calculated subdirectory path directly to io_dump
+                    TrainerContext.from_trainer(trainer).io_dump(context_target_dir, yaml_attrs=["model"])
+                    logging.info(f"[Callback] Successfully saved TrainerContext in: {str(context_target_dir)}")
+
+                # Mark step as saved only if all parts succeed
+                self._saved_steps_in_run.add(global_step)
+                logging.info(f"[Callback] Successfully completed saving for step: {global_step}")
+
+            except Exception as e:
+                logging.error(f"[Callback] Failed to save checkpoint or context at step {global_step}: {e}")
+
+
 class TimingCallback(Callback):
     """
     Logs execution time of train/val/test steps
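A minimal usage sketch for reviewers, assuming the new callback is imported from `nemo.utils.exp_manager` as added in this patch; the Lightning import alias (`lightning.pytorch` vs. `pytorch_lightning`), the step list, and the output path below are illustrative assumptions, not part of the diff:

```python
# Hypothetical wiring of SaveAtStepsCallback into a Lightning Trainer.
# Step values and paths are placeholders; adjust to the training job.
import lightning.pytorch as pl  # may be `import pytorch_lightning as pl` on older stacks

from nemo.utils.exp_manager import SaveAtStepsCallback

save_cb = SaveAtStepsCallback(
    save_steps=[100, 500, 1000],      # global steps at which to snapshot
    save_path="/results/step_ckpts",  # rank 0 creates this directory if missing
    filename_prefix="model_step",     # files become model_step_<step>.ckpt
    save_context=True,                # also dump TrainerContext beside each .ckpt
)

trainer = pl.Trainer(max_steps=1000, callbacks=[save_cb])
# trainer.fit(model, datamodule=datamodule)
```

Because `on_train_batch_end` returns early on non-zero ranks, only global rank 0 writes files, and `_saved_steps_in_run` prevents a step from being written twice when the same `global_step` is seen on consecutive batch ends (e.g., with gradient accumulation).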