From 3645cab254898cb36a17f3a9083623e78bd2afee Mon Sep 17 00:00:00 2001 From: Bryn Lloyd <12702862+dyollb@users.noreply.github.com> Date: Thu, 7 Nov 2024 11:20:13 +0100 Subject: [PATCH 1/3] remove model parameters in predict --- src/segmantic/seg/monai_unet.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/segmantic/seg/monai_unet.py b/src/segmantic/seg/monai_unet.py index 3a1049d..129339b 100644 --- a/src/segmantic/seg/monai_unet.py +++ b/src/segmantic/seg/monai_unet.py @@ -554,9 +554,6 @@ def predict( test_labels: Optional[list[Path]] = None, output_dir: Path = None, tissue_dict: dict[str, int] = None, - channels: tuple[int, ...] = (16, 32, 64, 128, 256), - strides: tuple[int, ...] = (2, 2, 2, 2), - dropout: float = 0.0, spacing: Sequence[float] = [], gpu_ids: list[int] = [], ) -> None: @@ -569,7 +566,7 @@ def predict( net: Net = Net.load_from_checkpoint(f"{model_file}", **settings) else: net = Net.load_from_checkpoint( - f"{model_file}", channels=channels, strides=strides, dropout=dropout + f"{model_file}" ) num_classes = net.num_classes From 397ff7cfc794bae9d110c3eed0e1bfc85d5284e4 Mon Sep 17 00:00:00 2001 From: Bryn Lloyd <12702862+dyollb@users.noreply.github.com> Date: Thu, 7 Nov 2024 11:28:32 +0100 Subject: [PATCH 2/3] forgot pre-commit --- src/segmantic/seg/monai_unet.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/segmantic/seg/monai_unet.py b/src/segmantic/seg/monai_unet.py index 129339b..0d0887c 100644 --- a/src/segmantic/seg/monai_unet.py +++ b/src/segmantic/seg/monai_unet.py @@ -565,9 +565,7 @@ def predict( settings = json.load(json_file) net: Net = Net.load_from_checkpoint(f"{model_file}", **settings) else: - net = Net.load_from_checkpoint( - f"{model_file}" - ) + net = Net.load_from_checkpoint(f"{model_file}") num_classes = net.num_classes net.freeze() From 9192e218a67948147ba421b860eb46ce6af52aef Mon Sep 17 00:00:00 2001 From: Bryn Lloyd <12702862+dyollb@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:24:38 +0100 Subject: [PATCH 3/3] fix static type errors --- src/segmantic/seg/monai_unet.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/segmantic/seg/monai_unet.py b/src/segmantic/seg/monai_unet.py index 0d0887c..0cdcf19 100644 --- a/src/segmantic/seg/monai_unet.py +++ b/src/segmantic/seg/monai_unet.py @@ -760,7 +760,7 @@ def cross_validate( for config_file in Path(config_files_dir).iterdir(): assert config_file.suffix in [".json", ".yml"], f"suffix: {config_file}" - is_json = config_file and config_file.suffix.lower() == ".json" + is_json = config_file.suffix.lower() == ".json" dumps = partial(config.dumps, is_json=is_json) loads = partial(config.loads, is_json=is_json) @@ -818,9 +818,6 @@ def cross_validate( test_images=test_images, test_labels=test_labels, tissue_dict=tissue_dict, - # channels=current_layers, - # strides=current_strides, - dropout=0.0, spacing=[1, 1, 1], gpu_ids=gpu_ids, )