Skip to content

Commit a288410

Browse files
authored
Merge pull request #93 from ziatdinovmax/master
Add a native image denoiser
2 parents 17528af + 7fcba6a commit a288410

File tree

9 files changed

+1766
-2
lines changed

9 files changed

+1766
-2
lines changed

atomai/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
from .imspec import ImSpec
33
from .regressor import Regressor
44
from .classifier import Classifier
5+
from .denoiser import DenoisingAutoencoder, denoise_images
56
from .dgm import BaseVAE, VAE, rVAE, jVAE, jrVAE
67
from .dklgp import dklGPR, Reconstructor
78
from .loaders import load_model, load_ensemble, load_pretrained_model
89

910
__all__ = ["Segmentor", "ImSpec", "BaseVAE", "VAE", "rVAE",
1011
"jVAE", "jrVAE", "load_model", "load_ensemble",
1112
"load_pretrained_model", "dklGPR", "Regressor",
12-
"Classifier", "Reconstructor"]
13+
"Classifier", "Reconstructor", "DenoisingAutoencoder",
14+
"denoise_images"]

atomai/models/denoiser.py

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
"""
2+
denoiser.py
3+
===========
4+
5+
Denoising autoencoder model for image cleaning
6+
7+
Created by Maxim Ziatdinov (email: [email protected])
8+
Modified with conventional batch normalization approach
9+
"""
10+
11+
from typing import Type, Union, Optional, Tuple
12+
import torch
13+
import numpy as np
14+
from ..trainers import BaseTrainer
15+
from ..predictors import BasePredictor
16+
from ..nets import ConvBlock, UpsampleBlock
17+
from ..utils import set_train_rng, preprocess_denoiser_data
18+
19+
20+
class DenoisingAutoencoder(BaseTrainer):
    """
    Denoising autoencoder model for image cleaning and noise reduction

    The network is an encoder (ConvBlock stages separated by 2x2 max pooling)
    followed by a decoder (ConvBlock stages separated by UpsampleBlock) and a
    final 1x1 convolution producing a single output channel.

    Args:
        encoder_filters: List of filter sizes for encoder layers (Default: [8, 16, 32, 64])
        decoder_filters: List of filter sizes for decoder layers (Default: [64, 32, 16, 8])
        encoder_layers: Number of convolutional layers per encoder block (Default: [1, 2, 2, 2])
        decoder_layers: Number of convolutional layers per decoder block (Default: [2, 2, 2, 1])
        use_batch_norm: Whether to use batch normalization in both encoder and decoder (Default: False)
        upsampling_mode: Upsampling method ('nearest' or 'bilinear') (Default: 'nearest')
        **seed: Random seed for reproducibility (Default: 1)

    Raises:
        ValueError: If a filter list and its matching layer-count list
            have different lengths

    Example:
        >>> # Initialize model
        >>> model = aoi.models.DenoisingAutoencoder()
        >>> # Train on noisy/clean image pairs
        >>> model.fit(noisy_images, clean_images, noisy_test, clean_test,
        ...           training_cycles=500, swa=True)
        >>> # Denoise new images
        >>> cleaned = model.predict(new_noisy_images)
    """

    def __init__(self,
                 encoder_filters: list = [8, 16, 32, 64],
                 decoder_filters: list = [64, 32, 16, 8],
                 encoder_layers: list = [1, 2, 2, 2],
                 decoder_layers: list = [2, 2, 2, 1],
                 use_batch_norm: bool = False,
                 upsampling_mode: str = 'nearest',
                 **kwargs) -> None:
        """
        Initialize denoising autoencoder
        """
        super(DenoisingAutoencoder, self).__init__()

        seed = kwargs.get("seed", 1)
        set_train_rng(seed)

        # Store private copies of the architecture parameters. The signature
        # defaults are shared list objects, so keeping a reference to them
        # (here and in meta_state_dict) would let one instance's mutation
        # leak into every later instance.
        self.encoder_filters = list(encoder_filters)
        self.decoder_filters = list(decoder_filters)
        self.encoder_layers = list(encoder_layers)
        self.decoder_layers = list(decoder_layers)
        self.use_batch_norm = use_batch_norm
        self.upsampling_mode = upsampling_mode

        # _build_autoencoder pairs filters with layer counts via zip(), which
        # would silently drop the unmatched tail; fail loudly instead.
        if len(self.encoder_filters) != len(self.encoder_layers):
            raise ValueError(
                "encoder_filters and encoder_layers must have the same length")
        if len(self.decoder_filters) != len(self.decoder_layers):
            raise ValueError(
                "decoder_filters and decoder_layers must have the same length")

        # Build the autoencoder
        self.net = self._build_autoencoder()
        self.net.to(self.device)

        # Initialize meta state dict for saving/loading
        self.meta_state_dict = {
            "model_type": "denoising_autoencoder",
            "encoder_filters": self.encoder_filters,
            "decoder_filters": self.decoder_filters,
            "encoder_layers": self.encoder_layers,
            "decoder_layers": self.decoder_layers,
            "use_batch_norm": use_batch_norm,
            "upsampling_mode": upsampling_mode,
            "weights": self.net.state_dict()
        }

    def _build_autoencoder(self) -> torch.nn.Module:
        """
        Build the encoder-decoder architecture with consistent batch norm placement

        Returns:
            torch.nn.Sequential made of two torch.nn.Sequential children,
            the encoder followed by the decoder
        """
        # Build encoder
        encoder_modules = []
        in_channels = 1  # Assuming grayscale images
        last_encoder_stage = len(self.encoder_filters) - 1

        for i, (filters, layers) in enumerate(zip(self.encoder_filters, self.encoder_layers)):
            # Add convolutional block with consistent batch norm usage
            encoder_modules.append(
                ConvBlock(ndim=2, nb_layers=layers, input_channels=in_channels,
                          output_channels=filters, batch_norm=self.use_batch_norm)
            )
            # Add max pooling (except for the last layer)
            if i < last_encoder_stage:
                encoder_modules.append(torch.nn.MaxPool2d(2, 2))
            in_channels = filters

        encoder = torch.nn.Sequential(*encoder_modules)

        # Build decoder. in_channels still holds the channel count of the
        # last encoder stage, which is what the first decoder block receives.
        decoder_modules = []

        for i, (filters, layers) in enumerate(zip(self.decoder_filters, self.decoder_layers)):
            # Add upsampling (except for the first layer); the channel count
            # is kept unchanged by the upsampling block
            if i > 0:
                decoder_modules.append(
                    UpsampleBlock(ndim=2, input_channels=in_channels,
                                  output_channels=in_channels, mode=self.upsampling_mode)
                )

            # Add convolutional block with same batch norm setting as encoder
            decoder_modules.append(
                ConvBlock(ndim=2, nb_layers=layers, input_channels=in_channels,
                          output_channels=filters, batch_norm=self.use_batch_norm)
            )
            in_channels = filters

        # Final output layer (no batch norm for final reconstruction)
        decoder_modules.append(torch.nn.Conv2d(in_channels, 1, 1))

        decoder = torch.nn.Sequential(*decoder_modules)

        # Combine encoder and decoder
        autoencoder = torch.nn.Sequential(encoder, decoder)

        return autoencoder

    def fit(self,
            X_train: Union[np.ndarray, torch.Tensor],
            y_train: Union[np.ndarray, torch.Tensor],
            X_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
            y_test: Optional[Union[np.ndarray, torch.Tensor]] = None,
            loss: str = 'mse',
            optimizer: Optional[Type[torch.optim.Optimizer]] = None,
            training_cycles: int = 500,
            batch_size: int = 32,
            compute_accuracy: bool = False,
            full_epoch: bool = False,
            swa: bool = True,
            perturb_weights: bool = False,
            **kwargs) -> None:
        """
        Train the denoising autoencoder

        Args:
            X_train: Noisy input images for training
            y_train: Clean target images for training
            X_test: Noisy input images for testing
            y_test: Clean target images for testing
            loss: Loss function (Default: 'mse')
            optimizer: Optimizer, forwarded to compile_trainer
            training_cycles: Number of training cycles
            batch_size: Batch size for training
            compute_accuracy: Whether to compute accuracy metrics
            full_epoch: Whether to use full epochs
            swa: Whether to use stochastic weight averaging
            perturb_weights: Whether to use weight perturbation
            **test_size: Fraction of the training data held out when no
                complete test set is given (Default: 0.15)
            **seed: Random state of the automatic train/test split (Default: 1)
            **kwargs: Additional arguments forwarded to compile_trainer
        """
        # Without a complete test pair, hold out part of the training data.
        # NOTE: if only one of X_test / y_test is supplied it is discarded.
        if X_test is None or y_test is None:
            from sklearn.model_selection import train_test_split
            X_train, X_test, y_train, y_test = train_test_split(
                X_train, y_train, test_size=kwargs.get("test_size", .15),
                shuffle=True, random_state=kwargs.get("seed", 1))

        # Preprocess data
        X_train, y_train, X_test, y_test = preprocess_denoiser_data(
            X_train, y_train, X_test, y_test)

        # Compile and run training
        self.compile_trainer(
            (X_train, y_train, X_test, y_test),
            loss=loss, optimizer=optimizer, training_cycles=training_cycles,
            batch_size=batch_size, compute_accuracy=compute_accuracy,
            full_epoch=full_epoch, swa=swa, perturb_weights=perturb_weights,
            **kwargs
        )

        self.run()

        # Update meta state dict
        self.meta_state_dict["weights"] = self.net.state_dict()

    def predict(self,
                data: Union[np.ndarray, torch.Tensor],
                **kwargs) -> np.ndarray:
        """
        Denoise input images

        Args:
            data: Input noisy images. A 2D input is treated as a single image
                and a 3D input as a stack of images; the missing batch and
                channel dimensions are added for both NumPy arrays and
                torch tensors.
            **kwargs: Forwarded both to the BasePredictor constructor and to
                its predict method (e.g. num_batches)

        Returns:
            Denoised images as a NumPy array with singleton dimensions removed
        """
        # str() makes the check work for plain strings and torch.device
        # objects alike, including indexed devices such as 'cuda:0'
        use_gpu = str(self.device).startswith('cuda')
        predictor = BasePredictor(self.net, use_gpu, **kwargs)

        # Ensure proper format for prediction
        if isinstance(data, (np.ndarray, torch.Tensor)):
            if data.ndim == 2:
                data = data[None, None, ...]  # Add batch and channel dims
            elif data.ndim == 3:
                data = data[:, None, ...]  # Add channel dim

        prediction = predictor.predict(data, **kwargs)

        return prediction.detach().cpu().numpy().squeeze()

    def load_weights(self, filepath: str) -> None:
        """
        Load saved model weights

        Args:
            filepath: Path to a file written with torch.save, holding either
                a meta dictionary with a "weights" entry or a bare state dict
        """
        # NOTE: torch.load unpickles the file; only load files from a
        # trusted source.
        weight_dict = torch.load(filepath, map_location=self.device)
        if "weights" in weight_dict:
            self.net.load_state_dict(weight_dict["weights"])
        else:
            self.net.load_state_dict(weight_dict)
224+
225+
226+
def init_denoising_autoencoder(**kwargs) -> Tuple[Type[torch.nn.Module], dict]:
    """
    Build a denoising autoencoder and expose its network and metadata

    Args:
        **kwargs: Passed unchanged to the DenoisingAutoencoder constructor

    Returns:
        Tuple of (model, meta_state_dict), i.e. the ``net`` and
        ``meta_state_dict`` attributes of the freshly built object
    """
    autoencoder = DenoisingAutoencoder(**kwargs)
    network, meta = autoencoder.net, autoencoder.meta_state_dict
    return network, meta
235+
236+
237+
# Convenience function for quick denoising
238+
def denoise_images(noisy_images: np.ndarray,
                   clean_images: np.ndarray,
                   test_noisy: Optional[np.ndarray] = None,
                   test_clean: Optional[np.ndarray] = None,
                   training_cycles: int = 500,
                   **kwargs) -> Tuple[DenoisingAutoencoder, Optional[np.ndarray]]:
    """
    Convenience function for training a denoising autoencoder and making predictions

    Args:
        noisy_images: Training noisy images
        clean_images: Training clean images
        test_noisy: Test noisy images (optional)
        test_clean: Test clean images (optional)
        training_cycles: Number of training cycles
        **kwargs: Additional arguments for model and training. Keywords that
            name a DenoisingAutoencoder constructor parameter configure the
            model; every other keyword is passed to ``fit``. ``seed`` is
            passed to both.

    Returns:
        Tuple of (trained_model, predictions_on_test_data); the second item
        is None when ``test_noisy`` is not provided
    """
    import inspect

    # Route each keyword to the call that understands it, so architecture
    # options are not leaked into fit() (and from there into the trainer).
    init_params = inspect.signature(DenoisingAutoencoder.__init__).parameters
    architecture_keys = {
        name for name, param in init_params.items()
        if name != "self"
        and param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    }
    model_kwargs = {key: value for key, value in kwargs.items()
                    if key in architecture_keys or key == "seed"}
    fit_kwargs = {key: value for key, value in kwargs.items()
                  if key not in architecture_keys}

    # Initialize model
    model = DenoisingAutoencoder(**model_kwargs)

    # Train model
    model.fit(noisy_images, clean_images, test_noisy, test_clean,
              training_cycles=training_cycles, **fit_kwargs)

    # Make predictions if test data provided
    predictions = None
    if test_noisy is not None:
        predictions = model.predict(test_noisy)

    return model, predictions

atomai/models/loaders.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .imspec import ImSpec
1818
from .regressor import Regressor
1919
from .classifier import Classifier
20+
from .denoiser import DenoisingAutoencoder
2021
from .dgm import BaseVAE, VAE, rVAE, jrVAE, jVAE
2122
from ..utils import average_weights
2223

@@ -49,6 +50,8 @@ def load_model(filepath: str) -> Union[Segmentor, Union[VAE, rVAE, jrVAE, jVAE],
4950
model = load_cls_model(loaded_dict)
5051
elif model_type == "vae":
5152
model = load_vae_model(loaded_dict)
53+
elif model_type == "denoising_autoencoder":
54+
model = load_denoising_autoencoder(loaded_dict)
5255
else:
5356
raise ValueError(
5457
"The model type {} cannot be loaded".format(model_type))
@@ -192,6 +195,46 @@ def load_vae_model(meta_dict: Dict[str, torch.Tensor]) -> Type[BaseVAE]:
192195
return m
193196

194197

198+
def load_denoising_autoencoder(meta_dict: Dict[str, torch.Tensor]) -> DenoisingAutoencoder:
    """
    Loads trained AtomAI denoising autoencoder models

    Args:
        meta_dict (dict):
            dictionary with trained weights and key information
            about model's structure. It is not modified.

    Returns:
        DenoisingAutoencoder object with NN in evaluation state

    Raises:
        KeyError: If ``meta_dict`` has no "weights" entry
    """
    # Work on a shallow copy so the caller's dictionary is left intact
    meta = dict(meta_dict)

    weights = meta.pop("weights")
    encoder_filters = meta.pop("encoder_filters", [8, 16, 32, 64])
    decoder_filters = meta.pop("decoder_filters", [64, 32, 16, 8])
    encoder_layers = meta.pop("encoder_layers", [1, 2, 2, 2])
    decoder_layers = meta.pop("decoder_layers", [2, 2, 2, 1])
    # Fall back to the constructor's own default (False) so that a dictionary
    # lacking this key rebuilds the same network the constructor would
    use_batch_norm = meta.pop("use_batch_norm", False)
    upsampling_mode = meta.pop("upsampling_mode", 'nearest')

    # Remember whether an optimizer was stored at all, so that an explicitly
    # stored value (even None) is still restored below
    has_optimizer = "optimizer" in meta
    optimizer = meta.pop("optimizer", None)

    # Remaining entries (e.g. "seed") are handed to the constructor as
    # keyword arguments
    model = DenoisingAutoencoder(
        encoder_filters=encoder_filters,
        decoder_filters=decoder_filters,
        encoder_layers=encoder_layers,
        decoder_layers=decoder_layers,
        use_batch_norm=use_batch_norm,
        upsampling_mode=upsampling_mode,
        **meta
    )

    model.net.load_state_dict(weights)
    if has_optimizer:
        model.optimizer = optimizer
    model.net.eval()
    return model
236+
237+
195238
def load_ensemble(filepath: str) -> Tuple[Type[torch.nn.Module], Dict[int, Dict[str, torch.Tensor]]]:
196239
"""
197240
Loads trained ensemble models

atomai/stat/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .multivar import (imlocal, calculate_transition_matrix,
22
sum_transitions, update_classes)
3+
from .fft_nmf import SlidingFFTNMF
34

45
__all__ = ['imlocal', 'calculate_transition_matrix', 'sum_transitions',
5-
'update_classes']
6+
'update_classes', 'SlidingFFTNMF']

0 commit comments

Comments
 (0)