@@ -431,7 +431,7 @@ def test_model_trainer_centered_instance(caplog, config, tmp_path: str):
         map_location="cpu",
         weights_only=False,
     )
-    assert checkpoint["epoch"] == 1
+    assert checkpoint["epoch"] == 0  # 0-indexed: after 1 epoch completes, epoch=0
 
     # check for training metrics csv
     path = (
@@ -1204,6 +1204,7 @@ def test_head_config_oneof_validation_error_no_head(config, caplog):
     and not torch.cuda.is_available(),  # self-hosted GPUs have linux os but cuda is available, so will do test
     reason="Flaky test (The training test runs on Ubuntu for a long time: >6hrs and then fails.)",
 )
+# TODO: Revisit this test later (Failing on ubuntu)
 def test_loading_pretrained_weights(
     config,
     sleap_centered_instance_model_path,
@@ -1212,20 +1213,12 @@ def test_loading_pretrained_weights(
     minimal_instance_centered_instance_ckpt,
     tmp_path,
 ):
-    """Test loading pretrained weights for model initialization.
-
-    Note: This test validates weight loading via log messages.
-    Actual weight value verification is done in test_lightning_modules.py.
-    Training is not needed to verify weight loading works.
-    """
-    from sleap_nn.training.lightning_modules import LightningModel
-
+    """Test loading pretrained weights for model initialization."""
     if torch.mps.is_available():
         config.trainer_config.trainer_accelerator = "cpu"
     else:
         config.trainer_config.trainer_accelerator = "auto"
-
-    # Test 1: Load keras (.h5) weights - verify log messages
+    # with keras (.h5 weights)
     sleap_nn_config = TrainingJobConfig.load_sleap_config(
         Path(sleap_centered_instance_model_path) / "training_config.json"
     )
@@ -1248,16 +1241,14 @@ def test_loading_pretrained_weights(
         train_labels=[sio.load_slp(minimal_instance)],
         val_labels=[sio.load_slp(minimal_instance)],
     )
-    # Just create the Lightning model (which loads weights) - no need to train
-    _ = LightningModel.get_lightning_model_from_config(config=trainer.config)
+    trainer.train()
 
     assert "Loading backbone weights from" in caplog.text
     assert "Successfully loaded 28/28 weights from legacy model" in caplog.text
     assert "Loading head weights from" in caplog.text
     assert "Successfully loaded 2/2 weights from legacy model" in caplog.text
 
-    # Test 2: Load .ckpt weights - verify log messages
-    caplog.clear()
+    # loading `.ckpt`
     sleap_nn_config = TrainingJobConfig.load_sleap_config(
         Path(sleap_centered_instance_model_path) / "initial_config.json"
     )
@@ -1279,7 +1270,7 @@ def test_loading_pretrained_weights(
         train_labels=[sio.load_slp(minimal_instance)],
         val_labels=[sio.load_slp(minimal_instance)],
     )
-    _ = LightningModel.get_lightning_model_from_config(config=trainer.config)
+    trainer.train()
 
     assert "Loading backbone weights from" in caplog.text
     assert "Loading head weights from" in caplog.text
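For reference (not part of the diff): with `trainer.train()` restored, the test verifies weight loading by inspecting the log messages emitted during training via pytest's `caplog` fixture. A minimal sketch of that log-assertion pattern, with hypothetical function and message names standing in for the library routine:

```python
import logging

logger = logging.getLogger(__name__)


def load_backbone_weights(path: str) -> None:
    # Hypothetical stand-in for the routine that loads legacy weights and logs progress.
    logger.info(f"Loading backbone weights from {path}")


def test_logs_weight_loading(caplog):
    # Capture log records at INFO level and assert on the captured text.
    with caplog.at_level(logging.INFO):
        load_backbone_weights("legacy_model.h5")
    assert "Loading backbone weights from" in caplog.text
```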