Skip to content

Commit ef56fd4

Browse files
committed
Merge branch 'training'
2 parents bc521e3 + e0fdb06 commit ef56fd4

File tree

4 files changed

+9
-6
lines changed

4 files changed

+9
-6
lines changed

src/data/classification_datamodule.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,21 +123,21 @@ def setup(self, stage: Optional[str] = None) -> None:
123123

124124
if hasattr(self.data_train, 'classes') and self._class_names is None:
125125
self._class_names = self.data_train.classes
126-
126+
127127
self.setup_stages_done.add('fit')
128128

129129
elif stage == 'predict':
130130
if 'predict' in self.setup_stages_done:
131131
return
132-
132+
133133
self.data_predict = ImageFolder(
134134
root=Path(self.test_data_dir),
135135
transform=self.val_test_transforms,
136136
)
137137

138138
if hasattr(self.data_predict, 'classes') and self._class_names is None:
139139
self._class_names = self.data_predict.classes
140-
140+
141141
self.setup_stages_done.add('predict')
142142

143143
def train_dataloader(self) -> DataLoader[Any]:

src/models/components/base_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@ def __init__(
6363
self.model_repo = model_repo
6464
self.model_kwargs = kwargs
6565
self.model = None
66-
66+
6767
def load_model(self, **additional_kwargs: Any) -> None:
6868
"""Load or reload the model with saved kwargs and optional additional kwargs.
69-
69+
7070
Args:
7171
**additional_kwargs: Additional keyword arguments to override or extend saved kwargs.
7272
"""

src/models/mnist_module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
scheduler: torch.optim.lr_scheduler,
4949
compile: bool,
5050
ckpt_path: str,
51+
num_classes: int = 10,
5152
) -> None:
5253
"""Initialize a `MNISTLitModule`.
5354

src/train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
4848
datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
4949

5050
log.info(f'Instantiating model <{cfg.model._target_}>')
51-
model: LightningModule = hydra.utils.instantiate(cfg.model, num_classes=datamodule.num_classes)
51+
model: LightningModule = hydra.utils.instantiate(
52+
cfg.model, num_classes=datamodule.num_classes
53+
)
5254

5355
log.info('Instantiating loggers...')
5456
loggers: list[Logger] = instantiate_loggers(cfg.get('logger'))

0 commit comments

Comments
 (0)