# Copyright 2024 Kiel University
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import os
import json
import pickle
import random
from copy import copy

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from PIL import Image
from dahuffman import HuffmanCodec

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import SubsetRandomSampler

import torchvision
import torchvision.models as models
from torchvision import datasets, transforms
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure

import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
import wandb

from compressai.layers import (
    GDN,
    AttentionBlock,
    ResidualBlock,
    ResidualBlockUpsample,
    ResidualBlockWithStride,
    conv3x3,
    conv1x1,
    subpel_conv3x3,
)
from compressai.models.utils import conv, deconv

device = 'cuda' if torch.cuda.is_available() else 'cpu'

class ResidualBottleneckBlock(nn.Module):
    """Residual bottleneck block.

    Introduced by [He2016], this block sandwiches a 3x3 convolution
    between two 1x1 convolutions which reduce and then restore the
    number of channels. This reduces the number of parameters required.

    [He2016]: `"Deep Residual Learning for Image Recognition"
    <https://arxiv.org/abs/1512.03385>`_, by Kaiming He, Xiangyu Zhang,
    Shaoqing Ren, and Jian Sun, CVPR 2016.

    Args:
        in_ch (int): Number of input channels
        out_ch (int): Number of output channels
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        mid_ch = min(in_ch, out_ch) // 2
        self.conv1 = conv1x1(in_ch, mid_ch)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(mid_ch, mid_ch)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv1x1(mid_ch, out_ch)
        self.skip = conv1x1(in_ch, out_ch) if in_ch != out_ch else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        identity = self.skip(x)

        out = x
        out = self.conv1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.relu2(out)
        out = self.conv3(out)

        return out + identity
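
# Quick shape sanity check for the block above (illustrative only): the block
# preserves the spatial size and maps in_ch -> out_ch, e.g.
#
#   block = ResidualBottleneckBlock(196, 196)
#   y = block(torch.randn(1, 196, 16, 16))  # -> torch.Size([1, 196, 16, 16])
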
def psnr_batch(img1, img2):
    # Per-image MSE over all pixels, then PSNR in dB (inputs assumed in [0, 1])
    mse = F.mse_loss(img1, img2, reduction='none')
    mse = mse.view(mse.size(0), -1).mean(dim=1)
    psnr_values = 20 * torch.log10(1.0 / torch.sqrt(mse))
    return torch.mean(psnr_values.detach().cpu())
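
# Worked example (inputs scaled to [0, 1]): a per-image MSE of 1e-3 gives
# 20 * log10(1 / sqrt(1e-3)) = 30 dB.
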
def ms_ssim_batch(img1, img2, data_range=1.0):
    ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=data_range).to(img1.device)

    # Calculate MS-SSIM for each image pair in the batch
    ms_ssim_values = [ms_ssim(img1[i].unsqueeze(0), img2[i].unsqueeze(0)).item() for i in range(img1.size(0))]

    # Convert the MS-SSIM values to dB
    ms_ssim_db_values = [ms_ssim_to_db(value) for value in ms_ssim_values]

    return torch.tensor(ms_ssim_db_values, device=img1.device)

def ms_ssim_to_db(ms_ssim):
    # Map an MS-SSIM score in [0, 1) to a dB scale; higher is better
    return -10 * np.log10(1 - ms_ssim)
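
# Worked example: ms_ssim_to_db(0.99) = -10 * log10(0.01) = 20.0 dB, so each
# additional "nine" of MS-SSIM quality adds 10 dB.
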
# Define the Convolutional Autoencoder

class Encoder(L.LightningModule):
    def __init__(self):
        super().__init__()

        # Encoder: 3 -> 12 channels with 8x spatial downsampling overall
        self.conv1 = nn.Conv2d(3, 16, 7, stride=2, padding=3)
        self.conv2 = nn.Conv2d(16, 16, 5, stride=4, padding=1)
        self.conv3 = nn.Conv2d(16, 12, 3, stride=1, padding=1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        return x
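
# Illustrative encoder shapes (not executed here): for a 224x224 RGB input,
# conv1 (stride 2) gives 16x112x112, conv2 (stride 4) gives 16x28x28, and
# conv3 keeps the resolution, so the latent is 12x28x28.

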
class Decoder(L.LightningModule):
    def __init__(self):
        super().__init__()

        # self.N = 126
        self.N = 196
        self.dec = nn.Sequential(
            AttentionBlock(12),
            deconv(12, self.N, kernel_size=5, stride=2),
            ResidualBottleneckBlock(self.N, self.N),
            ResidualBottleneckBlock(self.N, self.N),
            ResidualBottleneckBlock(self.N, self.N),
            AttentionBlock(self.N),
            ResidualBottleneckBlock(self.N, self.N),
            ResidualBottleneckBlock(self.N, self.N),
            ResidualBottleneckBlock(self.N, self.N),
            AttentionBlock(self.N),
            ResidualBottleneckBlock(self.N, self.N),
            ResidualBottleneckBlock(self.N, self.N),
            ResidualBottleneckBlock(self.N, self.N),
            deconv(self.N, self.N, kernel_size=5, stride=2),
            AttentionBlock(self.N),
            ResidualBottleneckBlock(self.N, self.N),
            ResidualBottleneckBlock(self.N, self.N),
            ResidualBottleneckBlock(self.N, self.N),
            AttentionBlock(self.N),
            ResidualBottleneckBlock(self.N, self.N),
            ResidualBottleneckBlock(self.N, self.N),
            ResidualBottleneckBlock(self.N, self.N),
            AttentionBlock(self.N),
            ResidualBottleneckBlock(self.N, self.N),
            ResidualBottleneckBlock(self.N, self.N),
            ResidualBottleneckBlock(self.N, self.N),
            deconv(self.N, 3, kernel_size=5, stride=2),
        )

    def forward(self, x):
        x = self.dec(x)
        # Map the reconstruction back into the [0, 1] pixel range
        x = torch.sigmoid(x)
        return x
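
# The decoder mirrors the encoder's 8x downsampling with three stride-2
# deconvolutions: an illustrative 12x28x28 latent grows to Nx56x56, then
# Nx112x112, and finally 3x224x224 before the sigmoid.

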
# Define the full autoencoder (encoder + decoder)

class MCUCoder(L.LightningModule):
    def __init__(self, loss=None):
        super().__init__()

        self.encoder = Encoder()
        self.decoder = Decoder()
        self.loss_func = loss
        self.ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)

        # Fraction of latent channels to keep; None means sample it per batch
        self.p = None
        self.replace_value = 0

        self.training_step_loss = []
        self.validation_step_psnr = []
        self.validation_step_ms_ssim = []
        self.saved_images = []

    def random_noise(self, x, r1, r2):
        # Add uniform noise drawn from [r2, r1) and clamp to the valid range
        temp_x = x.clone()
        noise = (r1 - r2) * torch.rand(x.shape, device=x.device) + r2
        return torch.clamp(temp_x + noise, min=0.0, max=1.0)

    def forward(self, x):
        # Encode, then zero out a subset of trailing latent channels
        x = self.encoder(x)
        x = self.rate_less(x)

        # At evaluation time, add uniform noise in [-0.01, 0.01) to the latent
        if not self.training:
            noise = torch.rand_like(x, dtype=torch.float) * 0.02 - 0.01
            x = x + noise

        x = self.decoder(x)
        self.rec_image = x.detach().clone()

        return x

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-4, betas=(0.9, 0.999))
        # Drop the learning rate by 10x after 95% of the training steps
        first_phase = int(self.trainer.max_steps * 0.95)
        # first_phase = 200_000

        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=first_phase, gamma=0.1)

        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    def training_step(self, train_batch, batch_idx):
        images = train_batch

        outputs = self(images)

        if self.loss_func == 'msssim':
            loss = 1 - self.ms_ssim(outputs, images)
            self.log('during_train_loss_ms_ssim', loss, on_epoch=True, prog_bar=True, logger=True)
        elif self.loss_func == 'mse':
            loss = nn.MSELoss()(outputs, images)
            self.log('during_train_loss_mse', loss, on_epoch=True, prog_bar=True, logger=True)
        else:
            raise ValueError(f"unknown loss function: {self.loss_func}")

        self.training_step_loss.append(loss)

        return {'loss': loss}

    def on_train_epoch_end(self):
        loss = torch.stack(self.training_step_loss).mean()
        self.log('train_loss_epoch', loss, on_epoch=True, prog_bar=True, logger=True)

        self.training_step_loss.clear()

        lr = self.trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0]
        self.log('learning_rate', lr, on_step=False, on_epoch=True)

    def validation_step(self, val_batch, batch_idx):
        self.saved_images = []
        msssim_temp = []
        psnr_temp = []

        # Evaluate at three rate points: keep 2, 6, or all 12 latent channels
        for t in [2, 6, 12]:
            self.p = t / 12
            images = val_batch
            outputs = self(images)

            psnr = psnr_batch(outputs, images)
            psnr_temp.append(psnr)

            ms_ssim = ms_ssim_batch(outputs, images)
            msssim_temp.append(ms_ssim)

            self.saved_images.append(self.rec_image[0])

        self.validation_step_psnr.append(psnr_temp)
        self.validation_step_ms_ssim.append(msssim_temp)

        self.p = None
        return {'loss': psnr, 'ms_ssim': ms_ssim}

    def on_validation_epoch_end(self):
        # Aggregate and log the metrics for each rate point (2, 6, 12 channels)
        for i, n_channels in enumerate([2, 6, 12]):
            psnr = torch.stack([x[i] for x in self.validation_step_psnr]).mean()
            self.log(f'val_psnr_{n_channels}l_epoch', psnr, on_epoch=True, prog_bar=True, logger=True)

            ms_ssim = torch.stack([x[i] for x in self.validation_step_ms_ssim]).mean()
            self.log(f'val_ms_ssim_{n_channels}l_epoch', ms_ssim, on_epoch=True, prog_bar=True, logger=True)

            self.logger.experiment.log({f'rec_image_{n_channels}l': wandb.Image(self.saved_images[i])})

        self.p = None
        self.validation_step_psnr.clear()
        self.validation_step_ms_ssim.clear()

    def rate_less(self, x):
        # Keep the first p-fraction of latent channels and zero out the rest
        temp_x = x.clone()
        if self.p is not None:
            # p is the fraction of channels to keep
            p = self.p
        else:
            # During training, sample a keep ratio from {1/12, ..., 12/12}
            p = np.random.randint(1, 13) / 12

        if p != 1.0:
            keep = int(p * x.shape[1])
            temp_x[:, keep:, :, :] = self.replace_value
        return temp_x
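
# Minimal usage sketch (illustrative only; assumes the dependencies imported
# above and a CUDA-capable device):
#
#   model = MCUCoder(loss='msssim').to(device)
#   model.eval()
#   model.p = 6 / 12                      # keep 6 of the 12 latent channels
#   x = torch.rand(1, 3, 224, 224, device=device)
#   with torch.no_grad():
#       x_hat = model(x)                  # reconstruction, same shape as x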