"""
denoiser.py
===========

Denoising autoencoder model for image cleaning

Created by Maxim Ziatdinov (email: [email protected])
Modified with conventional batch normalization approach
"""

from typing import Type, Union, Optional, Tuple
import torch
import numpy as np
from ..trainers import BaseTrainer
from ..predictors import BasePredictor
from ..nets import ConvBlock, UpsampleBlock
from ..utils import set_train_rng, preprocess_denoiser_data


class DenoisingAutoencoder(BaseTrainer):
    """
    Denoising autoencoder model for image cleaning and noise reduction

    Args:
        encoder_filters: List of filter sizes for encoder layers (Default: [8, 16, 32, 64])
        decoder_filters: List of filter sizes for decoder layers (Default: [64, 32, 16, 8])
        encoder_layers: Number of convolutional layers per encoder block (Default: [1, 2, 2, 2])
        decoder_layers: Number of convolutional layers per decoder block (Default: [2, 2, 2, 1])
        use_batch_norm: Whether to use batch normalization in both encoder and decoder (Default: True)
        upsampling_mode: Upsampling method ('nearest' or 'bilinear') (Default: 'nearest')
        **seed: Random seed for reproducibility (Default: 1)

    Example:
        >>> # Initialize model
        >>> model = aoi.models.DenoisingAutoencoder()
        >>> # Train on noisy/clean image pairs
        >>> model.fit(noisy_images, clean_images, noisy_test, clean_test,
        ...           training_cycles=500, swa=True)
        >>> # Denoise new images
        >>> cleaned = model.predict(new_noisy_images)
    """

    def __init__(self,
                 encoder_filters: list = [8, 16, 32, 64],
                 decoder_filters: list = [64, 32, 16, 8],
                 encoder_layers: list = [1, 2, 2, 2],
                 decoder_layers: list = [2, 2, 2, 1],
                 use_batch_norm: bool = True,
                 upsampling_mode: str = 'nearest',
                 **kwargs) -> None:
        """
        Initialize denoising autoencoder
        """
        super(DenoisingAutoencoder, self).__init__()

        seed = kwargs.get("seed", 1)
        set_train_rng(seed)

        # Store architecture parameters
        self.encoder_filters = encoder_filters
        self.decoder_filters = decoder_filters
        self.encoder_layers = encoder_layers
        self.decoder_layers = decoder_layers
        self.use_batch_norm = use_batch_norm
        self.upsampling_mode = upsampling_mode

        # Build the autoencoder
        self.net = self._build_autoencoder()
        self.net.to(self.device)

        # Initialize meta state dict for saving/loading
        self.meta_state_dict = {
            "model_type": "denoising_autoencoder",
            "encoder_filters": encoder_filters,
            "decoder_filters": decoder_filters,
            "encoder_layers": encoder_layers,
            "decoder_layers": decoder_layers,
            "use_batch_norm": use_batch_norm,
            "upsampling_mode": upsampling_mode,
            "weights": self.net.state_dict()
        }
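        # Everything needed to rebuild the network travels with the weights,
        # so torch.save(self.meta_state_dict, path) is enough for a full
        # round trip (see load_weights below).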

    def _build_autoencoder(self) -> torch.nn.Module:
        """
        Build the encoder-decoder architecture with consistent batch norm placement
        """
        # Build encoder
        encoder_modules = []
        in_channels = 1  # Assuming grayscale images

        for i, (filters, layers) in enumerate(zip(self.encoder_filters, self.encoder_layers)):
            # Add convolutional block with consistent batch norm usage
            encoder_modules.append(
                ConvBlock(ndim=2, nb_layers=layers, input_channels=in_channels,
                          output_channels=filters, batch_norm=self.use_batch_norm)
            )
            # Add max pooling (except after the last block)
            if i < len(self.encoder_filters) - 1:
                encoder_modules.append(torch.nn.MaxPool2d(2, 2))
            in_channels = filters

        encoder = torch.nn.Sequential(*encoder_modules)
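        # With the default four-block configuration, the three MaxPool2d(2, 2)
        # stages shrink spatial dimensions by 8x, e.g. a (N, 1, 64, 64) input
        # reaches the bottleneck as (N, 64, 8, 8).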

        # Build decoder
        decoder_modules = []

        for i, (filters, layers) in enumerate(zip(self.decoder_filters, self.decoder_layers)):
            # Add upsampling (except before the first block)
            if i > 0:
                decoder_modules.append(
                    UpsampleBlock(ndim=2, input_channels=in_channels,
                                  output_channels=in_channels, mode=self.upsampling_mode)
                )

            # Add convolutional block with same batch norm setting as encoder
            decoder_modules.append(
                ConvBlock(ndim=2, nb_layers=layers, input_channels=in_channels,
                          output_channels=filters, batch_norm=self.use_batch_norm)
            )
            in_channels = filters

        # Final output layer (no batch norm for final reconstruction)
        decoder_modules.append(torch.nn.Conv2d(in_channels, 1, 1))

        decoder = torch.nn.Sequential(*decoder_modules)

        # Combine encoder and decoder
        autoencoder = torch.nn.Sequential(encoder, decoder)

        return autoencoder

    def fit(self,
            X_train: Union[np.ndarray, torch.Tensor],
            y_train: Union[np.ndarray, torch.Tensor],
            X_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
            y_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
            loss: str = 'mse',
            optimizer: Optional[Type[torch.optim.Optimizer]] = None,
            training_cycles: int = 500,
            batch_size: int = 32,
            compute_accuracy: bool = False,
            full_epoch: bool = False,
            swa: bool = True,
            perturb_weights: bool = False,
            **kwargs):
        """
        Train the denoising autoencoder

        Args:
            X_train: Noisy input images for training
            y_train: Clean target images for training
            X_test: Noisy input images for testing
            y_test: Clean target images for testing
            loss: Loss function (Default: 'mse')
            optimizer: Optimizer (Default: Adam with lr=1e-3)
            training_cycles: Number of training cycles (one batch update
                per cycle unless full_epoch=True)
            batch_size: Batch size for training
            compute_accuracy: Whether to compute accuracy metrics
            full_epoch: Whether to train on full epochs instead of batches
            swa: Whether to use stochastic weight averaging
            perturb_weights: Whether to use weight perturbation
            **kwargs: Additional arguments for training
        """
        if X_test is None or y_test is None:
            from sklearn.model_selection import train_test_split
            X_train, X_test, y_train, y_test = train_test_split(
                X_train, y_train, test_size=kwargs.get("test_size", .15),
                shuffle=True, random_state=kwargs.get("seed", 1))

        # Preprocess data
        X_train, y_train, X_test, y_test = preprocess_denoiser_data(
            X_train, y_train, X_test, y_test)

        # Compile and run training
        self.compile_trainer(
            (X_train, y_train, X_test, y_test),
            loss=loss, optimizer=optimizer, training_cycles=training_cycles,
            batch_size=batch_size, compute_accuracy=compute_accuracy,
            full_epoch=full_epoch, swa=swa, perturb_weights=perturb_weights,
            **kwargs
        )

        self.run()

        # Update meta state dict
        self.meta_state_dict["weights"] = self.net.state_dict()

    def predict(self,
                data: Union[np.ndarray, torch.Tensor],
                **kwargs) -> np.ndarray:
        """
        Denoise input images

        Args:
            data: Input noisy images
            **num_batches: Number of batches for prediction (Default: 10)

        Returns:
            Denoised images
        """
        use_gpu = str(self.device).startswith('cuda')
        predictor = BasePredictor(self.net, use_gpu, **kwargs)

        # Ensure proper format for prediction
        if isinstance(data, np.ndarray):
            if data.ndim == 2:
                data = data[None, None, ...]  # Add batch and channel dims
            elif data.ndim == 3:
                data = data[:, None, ...]  # Add channel dim

        prediction = predictor.predict(data, **kwargs)

        return prediction.detach().cpu().numpy().squeeze()

    def load_weights(self, filepath: str) -> None:
        """
        Load saved model weights
        """
        weight_dict = torch.load(filepath, map_location=self.device)
        if "weights" in weight_dict:
            self.net.load_state_dict(weight_dict["weights"])
        else:
            self.net.load_state_dict(weight_dict)


def init_denoising_autoencoder(**kwargs) -> Tuple[torch.nn.Module, dict]:
    """
    Initialize a denoising autoencoder model

    Returns:
        Tuple of (model, meta_state_dict)
    """
    model = DenoisingAutoencoder(**kwargs)
    return model.net, model.meta_state_dict
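
# A sketch of the intended round trip, assuming the meta state dict was saved
# with torch.save(model.meta_state_dict, "denoiser.tar") ("denoiser.tar" is a
# hypothetical filename):
#
#     meta = torch.load("denoiser.tar", map_location="cpu")
#     arch = {k: v for k, v in meta.items() if k not in ("model_type", "weights")}
#     net, _ = init_denoising_autoencoder(**arch)
#     net.load_state_dict(meta["weights"])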


# Convenience function for quick denoising
def denoise_images(noisy_images: np.ndarray,
                   clean_images: np.ndarray,
                   test_noisy: Optional[np.ndarray] = None,
                   test_clean: Optional[np.ndarray] = None,
                   training_cycles: int = 500,
                   **kwargs) -> Tuple[DenoisingAutoencoder, Optional[np.ndarray]]:
    """
    Convenience function for training a denoising autoencoder and making predictions

    Args:
        noisy_images: Training noisy images
        clean_images: Training clean images
        test_noisy: Test noisy images (optional)
        test_clean: Test clean images (optional)
        training_cycles: Number of training cycles
        **kwargs: Additional arguments for model and training

    Returns:
        Tuple of (trained_model, predictions_on_test_data); predictions are
        None when no test data is provided
    """
    # Initialize model (note: **kwargs is forwarded to both the model
    # constructor and fit)
    model = DenoisingAutoencoder(**kwargs)

    # Train model
    model.fit(noisy_images, clean_images, test_noisy, test_clean,
              training_cycles=training_cycles, **kwargs)

    # Make predictions if test data was provided
    predictions = None
    if test_noisy is not None:
        predictions = model.predict(test_noisy)

    return model, predictions
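

if __name__ == "__main__":
    # Quick smoke test on synthetic data (a sketch, not a benchmark).
    # Note: because of the relative imports above, this module has to be run
    # from within its parent package (e.g. `python -m <package>.denoiser`,
    # where <package> is a placeholder for the actual package name).
    rng = np.random.default_rng(seed=1)
    clean = rng.random((64, 32, 32)).astype(np.float32)
    noisy = (clean + 0.1 * rng.standard_normal(clean.shape)).astype(np.float32)
    # A handful of training cycles is enough to check that the pipeline runs
    model, _ = denoise_images(noisy, clean, training_cycles=10)
    print(model.predict(noisy[:4]).shape)  # expected: (4, 32, 32)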