Skip to content

Commit 25da774

Browse files
committed
Merge branch
2 parents 314c607 + b843410 commit 25da774

File tree

3 files changed

+2
-68
lines changed

3 files changed

+2
-68
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ dependencies = [
2828
"av",
2929
"kornia",
3030
"hydra-core",
31-
"sleap-io>=0.1.0",
31+
"sleap-io==0.1.10",
3232
]
3333
dynamic = ["version", "readme"]
3434

sleap_nn/training/model_trainer.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -390,22 +390,8 @@ def train(
390390
self,
391391
backbone_trained_ckpts_path: Optional[str] = None,
392392
head_trained_ckpts_path: Optional[str] = None,
393-
delete_bin_files_after_training: bool = True,
394-
chunks_dir_path: Optional[str] = None,
395393
):
396-
"""Initiate the training by calling the fit method of Trainer.
397-
398-
Args:
399-
backbone_trained_ckpts_path: Path of the `ckpt` file with which the backbone
400-
is initialized. If `None`, random init is used.
401-
head_trained_ckpts_path: Path of the `ckpt` file with which the head layers
402-
are initialized. If `None`, random init is used.
403-
delete_bin_files_after_training: If `False`, the `bin` files are retained after
404-
training. Else, the `bin` files are deleted.
405-
chunks_dir_path: Path to chunks dir (this dir should contain `train_chunks`
406-
and `val_chunks` folder.). If `None`, `bin` files are generated.
407-
408-
"""
394+
"""Initiate the training by calling the fit method of Trainer."""
409395
logger = []
410396

411397
if self.config.trainer_config.save_ckpt:

tests/training/test_model_trainer.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -376,58 +376,6 @@ def test_trainer_load_trained_ckpts(config, tmp_path, minimal_instance_ckpt):
376376
assert np.all(np.abs(head_layer_ckpt - model_ckpt) < 1e-6)
377377

378378

379-
@pytest.mark.skipif(
380-
sys.platform.startswith("li"),
381-
reason="Flaky test (The training test runs on Ubuntu for a long time: >6hrs and then fails.)",
382-
)
383-
# TODO: Revisit this test later (Failing on ubuntu)
384-
def test_reuse_bin_files(config, tmp_path: str):
385-
"""Test reusing `.bin` files."""
386-
# Centroid model
387-
centroid_config = config.copy()
388-
head_config = config.model_config.head_configs.centered_instance
389-
OmegaConf.update(centroid_config, "model_config.head_configs.centroid", head_config)
390-
del centroid_config.model_config.head_configs.centered_instance
391-
del centroid_config.model_config.head_configs.centroid["confmaps"].part_names
392-
393-
OmegaConf.update(
394-
centroid_config,
395-
"trainer_config.save_ckpt_path",
396-
f"{tmp_path}/test_model_trainer/",
397-
)
398-
399-
if (Path(centroid_config.trainer_config.save_ckpt_path) / "best.ckpt").exists():
400-
os.remove(
401-
(
402-
Path(centroid_config.trainer_config.save_ckpt_path) / "best.ckpt"
403-
).as_posix()
404-
)
405-
os.remove(
406-
(
407-
Path(centroid_config.trainer_config.save_ckpt_path) / "last.ckpt"
408-
).as_posix()
409-
)
410-
shutil.rmtree(
411-
(
412-
Path(centroid_config.trainer_config.save_ckpt_path) / "lightning_logs"
413-
).as_posix()
414-
)
415-
416-
OmegaConf.update(centroid_config, "trainer_config.save_ckpt", True)
417-
OmegaConf.update(centroid_config, "trainer_config.use_wandb", False)
418-
OmegaConf.update(centroid_config, "trainer_config.max_epochs", 1)
419-
OmegaConf.update(centroid_config, "trainer_config.steps_per_epoch", 10)
420-
421-
# test reusing bin files
422-
trainer1 = ModelTrainer(centroid_config)
423-
trainer1.train(delete_bin_files_after_training=False)
424-
425-
trainer2 = ModelTrainer(centroid_config)
426-
trainer2.train(
427-
chunks_dir_path=(trainer1.train_input_dir).split("train_chunks")[0],
428-
)
429-
430-
431379
def test_topdown_centered_instance_model(config, tmp_path: str):
432380

433381
# unet

0 commit comments

Comments
 (0)