Skip to content

Commit 3807129

Browse files
committed
Apply isort and black reformatting
Signed-off-by: jomitchellnv <[email protected]>
1 parent 5146da4 commit 3807129

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

nemo/utils/exp_manager.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@
3939
from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector
4040
from omegaconf import DictConfig, OmegaConf, open_dict
4141

42-
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
4342
from nemo.collections.common.callbacks import EMA
4443
from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION
44+
from nemo.lightning.ckpt_utils import ckpt_to_context_subdir
4545
from nemo.utils import logging, timers
4646
from nemo.utils.app_state import AppState
4747
from nemo.utils.callbacks import NeMoModelCheckpoint, PreemptionCallback
@@ -268,6 +268,7 @@ class ExpManagerConfig:
268268

269269
try:
270270
from nemo.lightning.io.pl import TrainerContext
271+
271272
HAVE_TRAINER_CONTEXT = True
272273
except ImportError:
273274
logging.warning("[Callback] Cannot import TrainerContext. Will not save extra context information.")
@@ -280,7 +281,10 @@ class SaveAtStepsCallback(Callback):
280281
into a subdirectory at specific global step milestones.
281282
Ensures saving only happens on global rank 0.
282283
"""
283-
def __init__(self, save_steps: list[int], save_path: str, filename_prefix: str = "model_step", save_context: bool = True):
284+
285+
def __init__(
286+
self, save_steps: list[int], save_path: str, filename_prefix: str = "model_step", save_context: bool = True
287+
):
284288
"""
285289
Args:
286290
save_steps (list[int]): A list or tuple of global steps at which to save checkpoints.
@@ -297,7 +301,7 @@ def __init__(self, save_steps: list[int], save_path: str, filename_prefix: str =
297301
self.save_path = Path(save_path)
298302
self.filename_prefix = filename_prefix
299303
self._saved_steps_in_run = set()
300-
self.save_context = save_context and HAVE_TRAINER_CONTEXT # Control via init
304+
self.save_context = save_context and HAVE_TRAINER_CONTEXT # Control via init
301305

302306
# Create the save directory if it doesn't exist (only on rank 0)
303307
if is_global_rank_zero():
@@ -320,12 +324,12 @@ def on_train_batch_end(
320324

321325
if global_step in self.save_steps and global_step not in self._saved_steps_in_run:
322326
filename = f"{self.filename_prefix}_{global_step}.ckpt"
323-
filepath = self.save_path / filename # Use pathlib join
327+
filepath = self.save_path / filename # Use pathlib join
324328

325329
logging.info(f"[Callback] Saving checkpoint at global step {global_step} to {str(filepath)}")
326330
try:
327331
# 1. Save the .ckpt file using trainer
328-
trainer.save_checkpoint(filepath) # Pass Path object directly
332+
trainer.save_checkpoint(filepath) # Pass Path object directly
329333

330334
# 2. Optionally save the TrainerContext into its subdirectory
331335
if self.save_context:
@@ -456,8 +460,6 @@ def on_after_backward(self, trainer, pl_module):
456460
self._on_batch_end("train_backward_timing", pl_module)
457461

458462

459-
460-
461463
class DeltaTimingCallback(Callback):
462464
"""
463465
Logs execution time of train/val/test steps using nemo logger. Calculates

0 commit comments

Comments
 (0)