Skip to content

Commit 558e0a9

Browse files
gitttt-1234 and claude authored
Handle empty and "None" string values for run_name and ckpt_dir config parameters (#377)
## Summary

This PR adds defensive checks for empty and "None" string values in `run_name` and `ckpt_dir` config parameters, along with formatting improvements across the codebase.

## Changes

### Functional Changes

- **ModelTrainer._setup_ckpt_path()**: Added handling for empty string (`""`) and string literal `"None"` for both `run_name` and `ckpt_dir` parameters
- Prevents unexpected behavior when YAML configs have empty values (e.g., `run_name:` or `ckpt_dir:`)
- Handles edge case where users accidentally set string `"None"` instead of null

### Test Coverage

- Added test case for empty `run_name` in `test_model_ckpt_path_duplication`

### Code Formatting

- Removed extra blank lines across multiple files
- Added consistent spacing around operators in f-strings
- Properly wrapped tuple assignments in lightning_modules.py
- Fixed string concatenation in assertion message (inference/utils.py)

## Files Modified

- `sleap_nn/training/model_trainer.py` - defensive checks and formatting
- `sleap_nn/architectures/encoder_decoder.py` - formatting
- `sleap_nn/architectures/unet.py` - formatting
- `sleap_nn/inference/predictors.py` - formatting
- `sleap_nn/inference/topdown.py` - formatting
- `sleap_nn/inference/utils.py` - string formatting fix
- `sleap_nn/tracking/candidates/fixed_window.py` - formatting
- `sleap_nn/tracking/utils.py` - formatting
- `sleap_nn/training/lightning_modules.py` - tuple assignment formatting
- `tests/training/test_model_trainer.py` - added test coverage

## Testing

- Existing tests should pass with these changes
- Added specific test for empty run_name case

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <[email protected]>
Co-authored-by: Claude <[email protected]>
1 parent 1d46d08 commit 558e0a9

File tree

10 files changed

+47
-33
lines changed

10 files changed

+47
-33
lines changed

sleap_nn/architectures/encoder_decoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,8 @@ def __init__(
205205

206206
# Always finish with a pooling block to account for pooling before convs.
207207
final_pool_dict = OrderedDict()
208-
final_pool_dict[f"{self.prefix}{block+1}_last_pool"] = MaxPool2dWithSamePadding(
209-
kernel_size=2, stride=2, padding="same"
208+
final_pool_dict[f"{self.prefix}{block + 1}_last_pool"] = (
209+
MaxPool2dWithSamePadding(kernel_size=2, stride=2, padding="same")
210210
)
211211
self.stem_stack.append(nn.Sequential(final_pool_dict))
212212

sleap_nn/architectures/unet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ def __init__(
124124
)
125125
enc_num = len(encoder.encoder_stack)
126126
if self.middle_block:
127-
128127
if convs_per_block > 1:
129128
# Middle expansion block
130129
from sleap_nn.architectures.encoder_decoder import SimpleConvBlock

sleap_nn/inference/predictors.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,6 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]:
363363
done = False
364364

365365
try:
366-
367366
with Progress(
368367
"{task.description}",
369368
BarColumn(),
@@ -378,7 +377,6 @@ def _predict_generator(self) -> Iterator[Dict[str, np.ndarray]]:
378377
refresh_per_second=4, # Change to self.report_rate if needed
379378
speed_estimate_period=5,
380379
) as progress:
381-
382380
task = progress.add_task("Predicting...", total=total_frames)
383381
last_report = time()
384382

@@ -660,7 +658,6 @@ def _initialize_inference_model(self):
660658
instance_peaks_layer = FindInstancePeaksGroundTruth()
661659
self.instances_key = True
662660
else:
663-
664661
max_stride = self.confmap_config.model_config.backbone_config[
665662
f"{self.centered_instance_backbone_type}"
666663
]["max_stride"]
@@ -1604,7 +1601,6 @@ def _make_labeled_frames_from_generator(
16041601
ex["pred_peak_values"],
16051602
ex["orig_size"],
16061603
):
1607-
16081604
if np.isnan(pred_instances).all():
16091605
continue
16101606
inst = sio.PredictedInstance.from_numpy(
@@ -2046,7 +2042,6 @@ def _make_labeled_frames_from_generator(
20462042
ex["pred_peak_values"],
20472043
ex["instance_scores"],
20482044
):
2049-
20502045
# Loop over instances.
20512046
predicted_instances = []
20522047
for pts, confs, score in zip(
@@ -2488,7 +2483,6 @@ def _make_labeled_frames_from_generator(
24882483
ex["pred_peak_values"],
24892484
ex["instance_scores"],
24902485
):
2491-
24922486
# Loop over instances.
24932487
predicted_instances = []
24942488
for i, (pts, confs, score) in enumerate(

sleap_nn/inference/topdown.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,6 @@ def forward(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
799799
batch = self.centroid_crop(batch)
800800

801801
if batch is not None:
802-
803802
if isinstance(self.instance_peaks, FindInstancePeaksGroundTruth):
804803
peaks_output.append(self.instance_peaks(batch))
805804
else:

sleap_nn/inference/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def interp1d(x: torch.Tensor, y: torch.Tensor, xnew: torch.Tensor) -> torch.Tens
4747
v = {}
4848
eps = torch.finfo(y.dtype).eps
4949
for name, vec in {"x": x, "y": y, "xnew": xnew}.items():
50-
assert len(vec.shape) <= 2, "interp1d: all inputs must be " "at most 2-D."
50+
assert len(vec.shape) <= 2, "interp1d: all inputs must be at most 2-D."
5151
if len(vec.shape) == 1:
5252
v[name] = vec[None, :]
5353
else:

sleap_nn/tracking/candidates/fixed_window.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ def update_tracks(
134134
"""
135135
add_to_queue = True
136136
if row_inds is not None and col_inds is not None:
137-
138137
for idx, (row, col) in enumerate(zip(row_inds, col_inds)):
139138
current_instances.track_ids[row] = self.current_tracks[col]
140139
current_instances.tracking_scores[row] = tracking_scores[idx]

sleap_nn/tracking/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ def nms_fast(boxes, scores, iou_threshold, target_count=None) -> List[int]:
133133

134134
# keep looping while some indexes still remain in the indexes list
135135
while len(idxs) > 0:
136-
137136
# we want to add the best box which is the last box in sorted list
138137
picked_box_idx = idxs[-1]
139138

sleap_nn/training/lightning_modules.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -527,8 +527,9 @@ def forward(self, img):
527527

528528
def training_step(self, batch, batch_idx):
529529
"""Training step."""
530-
X, y = torch.squeeze(batch["image"], dim=1), torch.squeeze(
531-
batch["confidence_maps"], dim=1
530+
X, y = (
531+
torch.squeeze(batch["image"], dim=1),
532+
torch.squeeze(batch["confidence_maps"], dim=1),
532533
)
533534

534535
y_preds = self.model(X)["SingleInstanceConfmapsHead"]
@@ -574,8 +575,9 @@ def training_step(self, batch, batch_idx):
574575

575576
def validation_step(self, batch, batch_idx):
576577
"""Validation step."""
577-
X, y = torch.squeeze(batch["image"], dim=1), torch.squeeze(
578-
batch["confidence_maps"], dim=1
578+
X, y = (
579+
torch.squeeze(batch["image"], dim=1),
580+
torch.squeeze(batch["confidence_maps"], dim=1),
579581
)
580582

581583
y_preds = self.model(X)["SingleInstanceConfmapsHead"]
@@ -737,8 +739,9 @@ def forward(self, img):
737739

738740
def training_step(self, batch, batch_idx):
739741
"""Training step."""
740-
X, y = torch.squeeze(batch["instance_image"], dim=1), torch.squeeze(
741-
batch["confidence_maps"], dim=1
742+
X, y = (
743+
torch.squeeze(batch["instance_image"], dim=1),
744+
torch.squeeze(batch["confidence_maps"], dim=1),
742745
)
743746

744747
y_preds = self.model(X)["CenteredInstanceConfmapsHead"]
@@ -785,8 +788,9 @@ def training_step(self, batch, batch_idx):
785788

786789
def validation_step(self, batch, batch_idx):
787790
"""Perform validation step."""
788-
X, y = torch.squeeze(batch["instance_image"], dim=1), torch.squeeze(
789-
batch["confidence_maps"], dim=1
791+
X, y = (
792+
torch.squeeze(batch["instance_image"], dim=1),
793+
torch.squeeze(batch["confidence_maps"], dim=1),
790794
)
791795

792796
y_preds = self.model(X)["CenteredInstanceConfmapsHead"]
@@ -947,8 +951,9 @@ def forward(self, img):
947951

948952
def training_step(self, batch, batch_idx):
949953
"""Training step."""
950-
X, y = torch.squeeze(batch["image"], dim=1), torch.squeeze(
951-
batch["centroids_confidence_maps"], dim=1
954+
X, y = (
955+
torch.squeeze(batch["image"], dim=1),
956+
torch.squeeze(batch["centroids_confidence_maps"], dim=1),
952957
)
953958

954959
y_preds = self.model(X)["CentroidConfmapsHead"]
@@ -966,8 +971,9 @@ def training_step(self, batch, batch_idx):
966971

967972
def validation_step(self, batch, batch_idx):
968973
"""Validation step."""
969-
X, y = torch.squeeze(batch["image"], dim=1), torch.squeeze(
970-
batch["centroids_confidence_maps"], dim=1
974+
X, y = (
975+
torch.squeeze(batch["image"], dim=1),
976+
torch.squeeze(batch["centroids_confidence_maps"], dim=1),
971977
)
972978

973979
y_preds = self.model(X)["CentroidConfmapsHead"]

sleap_nn/training/model_trainer.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def _setup_train_val_labels(
226226
if skeletons_equal:
227227
total_train_lfs += len(train_label)
228228
else:
229-
message = f"The skeletons in the training labels: {index+1} do not match the skeleton in the first training label file."
229+
message = f"The skeletons in the training labels: {index + 1} do not match the skeleton in the first training label file."
230230
logger.error(message)
231231
raise ValueError(message)
232232

@@ -291,7 +291,6 @@ def _setup_preprocessing_config(self):
291291
):
292292
# compute crop size if not provided in config
293293
if crop_size is None:
294-
295294
crop_sz = find_instance_crop_size(
296295
labels=train_label,
297296
maximum_stride=self.config.model_config.backbone_config[
@@ -358,19 +357,19 @@ def _setup_ckpt_path(self):
358357
"""Setup checkpoint path."""
359358
# if run_name is None, assign a new dir name
360359
ckpt_dir = self.config.trainer_config.ckpt_dir
361-
if ckpt_dir is None:
360+
if ckpt_dir is None or ckpt_dir == "" or ckpt_dir == "None":
362361
ckpt_dir = "."
363362
self.config.trainer_config.ckpt_dir = ckpt_dir
364363
run_name = self.config.trainer_config.run_name
365-
if run_name is None:
364+
if run_name is None or run_name == "" or run_name == "None":
366365
sum_train_lfs = sum([len(train_label) for train_label in self.train_labels])
367366
sum_val_lfs = sum([len(val_label) for val_label in self.val_labels])
368367
if self._get_trainer_devices() > 1:
369-
run_name = f"{self.model_type}.n={sum_train_lfs+sum_val_lfs}"
368+
run_name = f"{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
370369
else:
371370
run_name = (
372371
datetime.now().strftime("%y%m%d_%H%M%S")
373-
+ f".{self.model_type}.n={sum_train_lfs+sum_val_lfs}"
372+
+ f".{self.model_type}.n={sum_train_lfs + sum_val_lfs}"
374373
)
375374

376375
# If checkpoint path already exists, add suffix to prevent overwriting
@@ -443,7 +442,6 @@ def _verify_model_input_channels(self):
443442
self.backbone_type == "unet"
444443
and self.config.model_config.pretrained_backbone_weights is not None
445444
):
446-
447445
if self.config.model_config.pretrained_backbone_weights.endswith(".ckpt"):
448446
pretrained_backbone_ckpt = torch.load(
449447
self.config.model_config.pretrained_backbone_weights,
@@ -648,7 +646,6 @@ def _setup_loggers_callbacks(self, viz_train_dataset, viz_val_dataset):
648646
loggers = []
649647
callbacks = []
650648
if self.config.trainer_config.save_ckpt:
651-
652649
# checkpoint callback
653650
checkpoint_callback = ModelCheckpoint(
654651
save_top_k=self.config.trainer_config.model_ckpt.save_top_k,

tests/training/test_model_trainer.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1267,6 +1267,27 @@ def test_model_ckpt_path_duplication(config, caplog, tmp_path, minimal_instance)
12671267
else:
12681268
config.trainer_config.trainer_accelerator = "auto"
12691269

1270+
# if run name is empty string
1271+
cfg_copy = config.copy()
1272+
OmegaConf.update(
1273+
cfg_copy,
1274+
"trainer_config.ckpt_dir",
1275+
f"{tmp_path}",
1276+
)
1277+
OmegaConf.update(
1278+
cfg_copy,
1279+
"trainer_config.save_ckpt",
1280+
True,
1281+
)
1282+
OmegaConf.update(cfg_copy, "trainer_config.run_name", "")
1283+
labels = sio.load_slp(minimal_instance)
1284+
trainer = ModelTrainer.get_model_trainer_from_config(
1285+
cfg_copy, train_labels=[labels], val_labels=[labels]
1286+
)
1287+
1288+
trainer.train()
1289+
1290+
# use an existing run name
12701291
config_duplicate_ckpt_path = config.copy()
12711292
OmegaConf.update(
12721293
config_duplicate_ckpt_path,

0 commit comments

Comments
 (0)