[not ready] Saves checkpoints at specified steps #12847

Open · wants to merge 3 commits into base: main
6 changes: 5 additions & 1 deletion nemo/lightning/resume.py
@@ -104,6 +104,9 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
raise NotImplementedError("Fabric is not supported yet.")

trainer_ckpt_path = self.get_trainer_ckpt_path(model)
# Need a way to actually restore the context.
model = _try_restore_tokenizer(model, os.path.join(trainer_ckpt_path.parent, "context"))

if trainer_ckpt_path:
trainer.ckpt_path = trainer_ckpt_path
trainer.checkpoint_callback.last_model_path = trainer_ckpt_path
@@ -270,7 +273,8 @@ def _find_trainer_ckpt_path(self) -> Optional[Path]:
return checkpoint

def get_context_path(self, model: Optional[io.ConnectorMixin] = None) -> Optional[Path]:
-    checkpoint = None
+    checkpoint = None  # ??? this is totally wrong.
app_state = AppState()
app_state.restore = self.resume_if_exists
if self.resume_if_exists:
85 changes: 84 additions & 1 deletion nemo/utils/exp_manager.py
@@ -41,6 +41,7 @@

from nemo.collections.common.callbacks import EMA
from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir

Check notice (Code scanning / CodeQL): Cyclic import
Import of module nemo.lightning.ckpt_utils begins an import cycle.

Copilot Autofix (AI, 11 days ago): Copilot could not generate an autofix suggestion for this alert.
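One common way to break this kind of cycle (not part of this PR or the CodeQL suggestion) is to defer the import to the point of use. A minimal sketch, assuming ckpt_to_context_subdir is only needed when a checkpoint is actually written and returns a path-like object:

# Sketch only (not part of this diff): resolve the import lazily so that
# nemo.utils.exp_manager no longer imports nemo.lightning.ckpt_utils at
# module-load time, which is what creates the cycle CodeQL reports.
from pathlib import Path


def _context_dir_for(ckpt_path: Path) -> Path:
    # Local import: evaluated only when a checkpoint is saved, after both
    # modules have finished initializing.
    from nemo.lightning.ckpt_utils import ckpt_to_context_subdir

    return ckpt_to_context_subdir(ckpt_path)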

from nemo.utils import logging, timers
from nemo.utils.app_state import AppState
from nemo.utils.callbacks import NeMoModelCheckpoint, PreemptionCallback
@@ -248,7 +249,7 @@
# Configures creation of log files for different ranks
log_local_rank_0_only: Optional[bool] = False
log_global_rank_0_only: Optional[bool] = False
-# disable initial validation when resuming from a checkpoint saved during validation
+# disable initial validation when resuming from a checkpoint
disable_validation_on_resume: Optional[bool] = True
ema: Optional[EMAParams] = field(default_factory=lambda: EMAParams())
# Wall clock time limit
@@ -265,6 +266,88 @@
log_tflops_per_sec_per_gpu: Optional[bool] = True


try:
from nemo.lightning.io.pl import TrainerContext

Check notice (Code scanning / CodeQL): Cyclic import
Import of module nemo.lightning.io.pl begins an import cycle.

Copilot Autofix (AI, 11 days ago):
To fix the cyclic import, avoid binding TrainerContext through a direct `from nemo.lightning.io.pl import TrainerContext` at module level in nemo/utils/exp_manager.py. Since TrainerContext is only used conditionally, the name can instead be resolved through a module alias inside the existing try block, so the cyclic dependency is not triggered during initial module loading.

Suggested change (nemo/utils/exp_manager.py):

@@ -269,4 +269,5 @@
 try:
-    from nemo.lightning.io.pl import TrainerContext
+    import nemo.lightning.io.pl as pl_module

+    TrainerContext = pl_module.TrainerContext
     HAVE_TRAINER_CONTEXT = True

HAVE_TRAINER_CONTEXT = True
except ImportError:
logging.warning("[Callback] Cannot import TrainerContext. Will not save extra context information.")
HAVE_TRAINER_CONTEXT = False


class SaveAtStepsCallback(Callback):
"""
Callback to save PTL checkpoints (.ckpt) and associated TrainerContext files
into a subdirectory at specific global step milestones.
Ensures saving only happens on global rank 0.
"""

def __init__(
self, save_steps: list[int], save_path: str, filename_prefix: str = "model_step", save_context: bool = True
):
"""
Args:
save_steps (list[int]): A list or tuple of global steps at which to save checkpoints.
save_path (str): The base directory where checkpoints and context subdirs will be saved.
This path should exist or be creatable by rank 0.
filename_prefix (str): Prefix for the checkpoint filename.
The final name will be f"{filename_prefix}_{global_step}.ckpt".
save_context (bool): Whether to also save the TrainerContext alongside the .ckpt file.
Defaults to True. Requires TrainerContext to be importable.
"""
super().__init__()
self.save_steps = set(save_steps)
# Ensure save_path is a Path object for consistency
self.save_path = Path(save_path)
self.filename_prefix = filename_prefix
self._saved_steps_in_run = set()
self.save_context = save_context and HAVE_TRAINER_CONTEXT # Control via init

# Create the save directory if it doesn't exist (only on rank 0)
if is_global_rank_zero():
# Use pathlib's mkdir
self.save_path.mkdir(parents=True, exist_ok=True)

# Ensure all ranks wait for rank 0 to potentially create the directory
if torch.distributed.is_initialized():
torch.distributed.barrier()

def on_train_batch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs, batch, batch_idx: int
) -> None:
"""Check if the current global step is one of the specified steps to save."""

if not is_global_rank_zero():
return

global_step = trainer.global_step

if global_step in self.save_steps and global_step not in self._saved_steps_in_run:
filename = f"{self.filename_prefix}_{global_step}.ckpt"
filepath = self.save_path / filename # Use pathlib join

logging.info(f"[Callback] Saving checkpoint at global step {global_step} to {str(filepath)}")
try:
# 1. Save the .ckpt file using trainer
trainer.save_checkpoint(filepath) # Pass Path object directly

# 2. Optionally save the TrainerContext into its subdirectory
if self.save_context:
# Use ckpt_to_context_subdir to get the target directory path
context_target_dir = ckpt_to_context_subdir(filepath)
logging.info(f"[Callback] Saving TrainerContext to {str(context_target_dir)}")
# Pass the calculated subdirectory path directly to io_dump
TrainerContext.from_trainer(trainer).io_dump(context_target_dir, yaml_attrs=["model"])
logging.info(f"[Callback] Successfully saved TrainerContext in: {str(context_target_dir)}")

# Mark step as saved only if all parts succeed
self._saved_steps_in_run.add(global_step)
logging.info(f"[Callback] Successfully completed saving for step: {global_step}")

except Exception as e:
logging.error(f"[Callback] Failed to save checkpoint or context at step {global_step}: {e}")


class TimingCallback(Callback):
"""
Logs execution time of train/val/test steps
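For reference, a minimal sketch of how the new SaveAtStepsCallback might be attached to a Lightning trainer. The Lightning import path and the trainer/model setup below are assumptions for illustration, not taken from this PR:

import pytorch_lightning as pl

from nemo.utils.exp_manager import SaveAtStepsCallback

# Save full .ckpt files (plus TrainerContext subdirectories) at steps 100, 500 and 1000.
save_cb = SaveAtStepsCallback(
    save_steps=[100, 500, 1000],
    save_path="/results/step_checkpoints",  # created by global rank 0 if missing
    filename_prefix="model_step",           # -> model_step_<global_step>.ckpt
    save_context=True,                      # silently skipped if TrainerContext cannot be imported
)

trainer = pl.Trainer(max_steps=1000, callbacks=[save_cb])
# trainer.fit(model, datamodule=datamodule)  # model/datamodule defined elsewhere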