Skip to content

Commit bc521e3

Browse files
committed
Merge branch 'training'
2 parents dc66a38 + 628f5b1 commit bc521e3

File tree

13 files changed

+48
-40
lines changed

13 files changed

+48
-40
lines changed

configs/data/classification_dir.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,4 @@ val_test_transforms:
2626
size: ${data.image_size}
2727
- _target_: torchvision.transforms.ToTensor
2828

29-
save_predict_images: false
30-
num_classes: ???
29+
save_predict_images: false

configs/experiment/train_cnn.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,14 @@ trainer:
2020
data:
2121
batch_size: 32
2222
num_workers: 1
23-
num_classes: 6
2423

2524
model:
2625
net:
2726
_target_: src.models.components.base_model.BaseModel
2827
model_name: timm/mobilenetv3_large_100.ra_in1k
2928
pretrained: True
30-
num_classes: ${data.num_classes}
3129
loss:
3230
_target_: torch.nn.CrossEntropyLoss
3331
ckpt_path: null
34-
num_classes: ${data.num_classes}
3532

3633
export_to_onnx: True

configs/experiment/train_cnn_multi.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,8 @@ trainer:
2020
data:
2121
batch_size: 8
2222
num_workers: 1
23-
num_classes: 2
2423

2524
model:
2625
net:
2726
multi_head: false
2827
ckpt_path: null
29-
num_classes: ${data.num_classes}

configs/experiment/train_grain.yaml

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@ tags: ["classification", "train", "cnn"]
1212
seed: 12345
1313

1414
trainer:
15-
min_epochs: 10
16-
max_epochs: 20
17-
gradient_clip_val: 0.5
15+
min_epochs: 20
16+
max_epochs: 35
1817
accelerator: gpu
19-
precision: "bf16"
18+
precision: "bf16-mixed"
2019

2120
data:
2221
batch_size: 32
@@ -42,19 +41,15 @@ data:
4241
replicate_borders: false
4342
inpaint_color: [215, 215, 215]
4443
- _target_: torchvision.transforms.ToTensor
45-
46-
num_classes: 6
4744

4845
model:
4946
net:
5047
_target_: src.models.components.base_model.BaseModel
5148
model_name: timm/mobilenetv3_large_100.ra_in1k
5249
pretrained: True
53-
num_classes: ${data.num_classes}
5450
loss:
5551
_target_: torch.nn.CrossEntropyLoss
5652
ckpt_path: null
57-
num_classes: ${data.num_classes}
5853

5954
callbacks:
6055
model_checkpoint:

configs/experiment/train_vit.yaml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,15 @@ trainer:
1919
data:
2020
batch_size: 64
2121
num_workers: 1
22-
num_classes: 6
2322

2423
model:
2524
net:
2625
_target_: src.models.components.base_model.BaseModel
2726
model_name: timm/vit_tiny_patch16_224.augreg_in21k_ft_in1k
2827
pretrained: True
29-
num_classes: ${data.num_classes}
3028
img_size: ${data.image_size}
3129
loss:
3230
_target_: torch.nn.CrossEntropyLoss
3331
ckpt_path: null
34-
num_classes: ${data.num_classes}
3532

3633
export_to_onnx: True

configs/experiment/train_vit_multi.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ trainer:
1919
data:
2020
batch_size: 64
2121
num_workers: 1
22-
num_classes: 2
2322

2423
model:
2524
net:
@@ -28,4 +27,3 @@ model:
2827
lr: 5e-5
2928
compile: false
3029
ckpt_path: null
31-
num_classes: ${data.num_classes}

configs/model/cnn.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,3 @@ loss:
2222
compile: false
2323

2424
ckpt_path: ???
25-
26-
num_classes: ???

configs/model/vit.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,4 @@ loss:
1616

1717
compile: false
1818

19-
ckpt_path: ???
20-
21-
num_classes: ???
19+
ckpt_path: ???

src/data/classification_datamodule.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def __init__(
2525
train_transforms: Compose = None,
2626
val_test_transforms: Compose = None,
2727
save_predict_images: bool = False,
28-
num_classes: int = 2,
2928
) -> None:
3029
"""Initialize a `DirDataModule`.
3130
@@ -41,7 +40,6 @@ def __init__(
4140
train_transforms (Compose, optional): Train split transformations. Defaults to None.
4241
val_test_transforms (Compose, optional): Validation and test split transformations. Defaults to None.
4342
save_predict_images (bool, optional): Save images in predict mode? Defaults to False.
44-
num_classes (int, optional): Number of classes in the dataset.
4543
"""
4644
super().__init__()
4745

@@ -55,29 +53,31 @@ def __init__(
5553
self.train_transforms = train_transforms
5654
self.val_test_transforms = val_test_transforms
5755
self.save_predict_images = save_predict_images
58-
self._num_classes = num_classes
5956
self.channels = channels
6057
self._class_names: Optional[list[str]] = None
6158
self.data_train: Optional[Dataset] = None
6259
self.data_val: Optional[Dataset] = None
6360
self.data_test: Optional[Dataset] = None
6461
self.data_predict: Optional[Dataset] = None
62+
self.setup_stages_done = set()
6563

6664
@property
6765
def num_classes(self) -> int:
6866
"""Get the number of classes.
6967
7068
Returns:
71-
int: The number of classes (2).
69+
int: The number of classes.
7270
"""
73-
return self._num_classes
71+
if self._class_names is None and self.data_train is None:
72+
self.setup(stage='fit')
73+
74+
return len(self._class_names)
7475

7576
@property
76-
def class_names(self):
77+
def class_names(self) -> Optional[list[str]]:
7778
"""Automatically extract class names from the dataset."""
78-
79-
if self._class_names is None and hasattr(self.data_train, 'classes'):
80-
self._class_names = self.data_train.classes
79+
if self._class_names is None and self.data_train is None:
80+
self.setup(stage='fit')
8181

8282
return self._class_names
8383

@@ -102,8 +102,10 @@ def setup(self, stage: Optional[str] = None) -> None:
102102
stage (Optional[str], optional): The stage to setup. Either `"fit"`,
103103
`"validate"`, `"test"`, or `"predict"`. Defaults to None.
104104
"""
105-
106105
if stage in {'fit', 'validate', 'test'}:
106+
if 'fit' in self.setup_stages_done:
107+
return
108+
107109
self.data_train = ImageFolder(
108110
root=Path(self.train_data_dir),
109111
transform=self.train_transforms,
@@ -118,12 +120,26 @@ def setup(self, stage: Optional[str] = None) -> None:
118120
root=Path(self.val_data_dir),
119121
transform=self.val_test_transforms,
120122
)
123+
124+
if hasattr(self.data_train, 'classes') and self._class_names is None:
125+
self._class_names = self.data_train.classes
126+
127+
self.setup_stages_done.add('fit')
128+
121129
elif stage == 'predict':
130+
if 'predict' in self.setup_stages_done:
131+
return
132+
122133
self.data_predict = ImageFolder(
123134
root=Path(self.test_data_dir),
124135
transform=self.val_test_transforms,
125136
)
126137

138+
if hasattr(self.data_predict, 'classes') and self._class_names is None:
139+
self._class_names = self.data_predict.classes
140+
141+
self.setup_stages_done.add('predict')
142+
127143
def train_dataloader(self) -> DataLoader[Any]:
128144
"""Create and return the train dataloader.
129145

src/models/classification_module.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(
1717
loss: torch.nn.modules.loss,
1818
compile: bool,
1919
ckpt_path: str,
20-
num_classes: int = None,
20+
num_classes: int,
2121
) -> None:
2222
"""Initialize lightning module.
2323
@@ -28,7 +28,7 @@ def __init__(
2828
loss (torch.nn.modules.loss): Loss function.
2929
compile (bool): Compile model.
3030
ckpt_path (string): Model chekpoint path.
31-
num_classes (int, optional): Number of classes.
31+
num_classes (int): Number of classes.
3232
"""
3333
super().__init__()
3434
# model
@@ -220,6 +220,7 @@ def setup(self, stage: str) -> None:
220220
Args:
221221
stage (str): Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
222222
"""
223+
self.net.load_model(num_classes=self.num_classes)
223224
if self.compile and stage == 'fit':
224225
self.net = torch.compile(self.net)
225226
if self.ckpt_path:

0 commit comments

Comments
 (0)