@@ -25,7 +25,6 @@ def __init__(
2525 train_transforms : Compose = None ,
2626 val_test_transforms : Compose = None ,
2727 save_predict_images : bool = False ,
28- num_classes : int = 2 ,
2928 ) -> None :
3029 """Initialize a `DirDataModule`.
3130
@@ -41,7 +40,6 @@ def __init__(
4140 train_transforms (Compose, optional): Train split transformations. Defaults to None.
4241 val_test_transforms (Compose, optional): Validation and test split transformations. Defaults to None.
4342 save_predict_images (bool, optional): Save images in predict mode? Defaults to False.
44- num_classes (int, optional): Number of classes in the dataset.
4543 """
4644 super ().__init__ ()
4745
@@ -55,29 +53,31 @@ def __init__(
5553 self .train_transforms = train_transforms
5654 self .val_test_transforms = val_test_transforms
5755 self .save_predict_images = save_predict_images
58- self ._num_classes = num_classes
5956 self .channels = channels
6057 self ._class_names : Optional [list [str ]] = None
6158 self .data_train : Optional [Dataset ] = None
6259 self .data_val : Optional [Dataset ] = None
6360 self .data_test : Optional [Dataset ] = None
6461 self .data_predict : Optional [Dataset ] = None
62+ self .setup_stages_done = set ()
6563
6664 @property
6765 def num_classes (self ) -> int :
6866 """Get the number of classes.
6967
7068 Returns:
71- int: The number of classes (2) .
69+ int: The number of classes.
7270 """
73- return self ._num_classes
71+ if self ._class_names is None and self .data_train is None :
72+ self .setup (stage = 'fit' )
73+
74+ return len (self ._class_names )
7475
7576 @property
76- def class_names (self ):
77+ def class_names (self ) -> Optional [ list [ str ]] :
7778 """Automatically extract class names from the dataset."""
78-
79- if self ._class_names is None and hasattr (self .data_train , 'classes' ):
80- self ._class_names = self .data_train .classes
79+ if self ._class_names is None and self .data_train is None :
80+ self .setup (stage = 'fit' )
8181
8282 return self ._class_names
8383
@@ -102,8 +102,10 @@ def setup(self, stage: Optional[str] = None) -> None:
102102 stage (Optional[str], optional): The stage to setup. Either `"fit"`,
103103 `"validate"`, `"test"`, or `"predict"`. Defaults to None.
104104 """
105-
106105 if stage in {'fit' , 'validate' , 'test' }:
106+ if 'fit' in self .setup_stages_done :
107+ return
108+
107109 self .data_train = ImageFolder (
108110 root = Path (self .train_data_dir ),
109111 transform = self .train_transforms ,
@@ -118,12 +120,26 @@ def setup(self, stage: Optional[str] = None) -> None:
118120 root = Path (self .val_data_dir ),
119121 transform = self .val_test_transforms ,
120122 )
123+
124+ if hasattr (self .data_train , 'classes' ) and self ._class_names is None :
125+ self ._class_names = self .data_train .classes
126+
127+ self .setup_stages_done .add ('fit' )
128+
121129 elif stage == 'predict' :
130+ if 'predict' in self .setup_stages_done :
131+ return
132+
122133 self .data_predict = ImageFolder (
123134 root = Path (self .test_data_dir ),
124135 transform = self .val_test_transforms ,
125136 )
126137
138+ if hasattr (self .data_predict , 'classes' ) and self ._class_names is None :
139+ self ._class_names = self .data_predict .classes
140+
141+ self .setup_stages_done .add ('predict' )
142+
127143 def train_dataloader (self ) -> DataLoader [Any ]:
128144 """Create and return the train dataloader.
129145
0 commit comments