Skip to content

Commit 2248909

Browse files
author
Ali Hojjat
committed
First files
1 parent 5e1053f commit 2248909

14 files changed

+1181
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .data_preprocessing import *
2+
from .model import *
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright 2024 Kiel University
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import lightning as L
7+
import torch
8+
from torchvision import datasets, transforms
9+
from torch.utils.data import Dataset
10+
import os
11+
from PIL import Image
12+
13+
class CustomImageDataset(Dataset):
14+
def __init__(self, root_dir, transform=None):
15+
self.root_dir = root_dir
16+
self.transform = transform
17+
self.image_files = [f for f in os.listdir(root_dir) if os.path.isfile(os.path.join(root_dir, f))]
18+
19+
def __len__(self):
20+
return len(self.image_files)
21+
22+
def __getitem__(self, idx):
23+
img_name = os.path.join(self.root_dir, self.image_files[idx])
24+
image = Image.open(img_name)
25+
if self.transform:
26+
image = self.transform(image)
27+
return image
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
# Copyright 2024 Kiel University
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
import numpy as np
8+
import torch
9+
import torch.nn as nn
10+
import torch.nn.functional as F
11+
import torch.optim as optim
12+
import matplotlib.pyplot as plt
13+
from torch.utils.data.sampler import SubsetRandomSampler
14+
from torch.utils.data import DataLoader
15+
import math
16+
from torchvision import datasets, transforms
17+
import matplotlib.pyplot as plt
18+
import torch.nn as nn
19+
import torch.nn.functional as F
20+
import seaborn as sns
21+
import pandas as pd
22+
import lightning as L
23+
from copy import copy
24+
import torchvision
25+
import random
26+
from dahuffman import HuffmanCodec
27+
import numpy as np
28+
import cv2
29+
import torch
30+
import json
31+
from PIL import Image
32+
import torch
33+
from torchvision import transforms
34+
from torch.utils.data import Dataset
35+
import os
36+
import pickle
37+
from lightning.pytorch.loggers import WandbLogger
38+
from lightning.pytorch.callbacks import LearningRateMonitor
39+
import wandb
40+
import torchvision.models as models
41+
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
42+
from compressai.layers import GDN
43+
from compressai.layers import (
44+
AttentionBlock,
45+
ResidualBlock,
46+
ResidualBlockUpsample,
47+
ResidualBlockWithStride,
48+
conv3x3,
49+
conv1x1,
50+
subpel_conv3x3,
51+
)
52+
from compressai.models.utils import conv, deconv
53+
from torch import Tensor
54+
55+
device = 'cuda'
56+
57+
class ResidualBottleneckBlock(nn.Module):
58+
"""Residual bottleneck block.
59+
60+
Introduced by [He2016], this block sandwiches a 3x3 convolution
61+
between two 1x1 convolutions which reduce and then restore the
62+
number of channels. This reduces the number of parameters required.
63+
64+
[He2016]: `"Deep Residual Learning for Image Recognition"
65+
<https://arxiv.org/abs/1512.03385>`_, by Kaiming He, Xiangyu Zhang,
66+
Shaoqing Ren, and Jian Sun, CVPR 2016.
67+
68+
Args:
69+
in_ch (int): Number of input channels
70+
out_ch (int): Number of output channels
71+
"""
72+
73+
def __init__(self, in_ch: int, out_ch: int):
74+
super().__init__()
75+
mid_ch = min(in_ch, out_ch) // 2
76+
self.conv1 = conv1x1(in_ch, mid_ch)
77+
self.relu1 = nn.ReLU(inplace=True)
78+
self.conv2 = conv3x3(mid_ch, mid_ch)
79+
self.relu2 = nn.ReLU(inplace=True)
80+
self.conv3 = conv1x1(mid_ch, out_ch)
81+
self.skip = conv1x1(in_ch, out_ch) if in_ch != out_ch else nn.Identity()
82+
83+
def forward(self, x: Tensor) -> Tensor:
84+
identity = self.skip(x)
85+
86+
out = x
87+
out = self.conv1(out)
88+
out = self.relu1(out)
89+
out = self.conv2(out)
90+
out = self.relu2(out)
91+
out = self.conv3(out)
92+
93+
return out + identity
94+
95+
def psnr_batch(img1, img2):
96+
mse = F.mse_loss(img1, img2, reduction='none')
97+
mse = mse.view(mse.size(0), -1).mean(dim=1)
98+
psnr_values = 20 * torch.log10(1.0 / torch.sqrt(mse))
99+
return torch.mean(psnr_values.detach().cpu())
100+
101+
def ms_ssim_batch(img1, img2, data_range=1.0):
102+
103+
ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=data_range).to(device)
104+
105+
# Calculate MS-SSIM for each image pair in the batch
106+
ms_ssim_values = [ms_ssim(img1[i].unsqueeze(0), img2[i].unsqueeze(0)).item() for i in range(img1.size(0))]
107+
108+
# Convert MS-SSIM values to dB
109+
ms_ssim_db_values = [ms_ssim_to_db(value) for value in ms_ssim_values]
110+
111+
return torch.tensor(ms_ssim_db_values, device='cuda')
112+
113+
def ms_ssim_to_db(ms_ssim):
114+
return -10 * np.log10(1 - ms_ssim)
115+
116+
#Define the Convolutional Autoencoder
117+
118+
class Encoder(L.LightningModule):
119+
def __init__(self):
120+
super(Encoder, self).__init__()
121+
122+
# Encoder
123+
self.conv1 = nn.Conv2d(3, 16, 7, stride=2, padding=3)
124+
self.conv2 = nn.Conv2d(16, 16, 5, stride=4, padding=1)
125+
self.conv3 = nn.Conv2d(16, 12, 3, stride=1, padding=1)
126+
127+
128+
def forward(self, x):
129+
x = F.relu(self.conv1(x))
130+
x = F.relu(self.conv2(x))
131+
x = self.conv3(x)
132+
return x
133+
134+
135+
136+
class Decoder(L.LightningModule):
137+
def __init__(self):
138+
super(Decoder, self).__init__()
139+
140+
# self.N=126
141+
self.N=196
142+
self.dec = nn.Sequential(
143+
AttentionBlock(12),
144+
deconv(12, self.N, kernel_size=5, stride=2),
145+
146+
ResidualBottleneckBlock(self.N, self.N),
147+
ResidualBottleneckBlock(self.N, self.N),
148+
ResidualBottleneckBlock(self.N, self.N),
149+
AttentionBlock(self.N),
150+
ResidualBottleneckBlock(self.N, self.N),
151+
ResidualBottleneckBlock(self.N, self.N),
152+
ResidualBottleneckBlock(self.N, self.N),
153+
AttentionBlock(self.N),
154+
ResidualBottleneckBlock(self.N, self.N),
155+
ResidualBottleneckBlock(self.N, self.N),
156+
ResidualBottleneckBlock(self.N, self.N),
157+
deconv(self.N, self.N, kernel_size=5, stride=2),
158+
AttentionBlock(self.N),
159+
ResidualBottleneckBlock(self.N, self.N),
160+
ResidualBottleneckBlock(self.N, self.N),
161+
ResidualBottleneckBlock(self.N, self.N),
162+
AttentionBlock(self.N),
163+
ResidualBottleneckBlock(self.N, self.N),
164+
ResidualBottleneckBlock(self.N, self.N),
165+
ResidualBottleneckBlock(self.N, self.N),
166+
AttentionBlock(self.N),
167+
ResidualBottleneckBlock(self.N, self.N),
168+
ResidualBottleneckBlock(self.N, self.N),
169+
ResidualBottleneckBlock(self.N, self.N),
170+
deconv(self.N, 3, kernel_size=5, stride=2),
171+
)
172+
173+
def forward(self, x):
174+
x = self.dec(x)
175+
x = torch.sigmoid(x)
176+
return x
177+
178+
179+
#Define the Convolutional Autoencoder
180+
181+
class MCUCoder(L.LightningModule):
182+
def __init__(self, loss=None):
183+
super(MCUCoder, self).__init__()
184+
#
185+
self.encoder = Encoder()
186+
self.decoder = Decoder()
187+
self.loss_func = loss
188+
self.ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
189+
190+
self.p = None
191+
self.replace_value = 0
192+
193+
self.training_step_loss = []
194+
self.validation_step_psnr = []
195+
self.validation_step_ms_ssim = []
196+
self.saved_images=[]
197+
198+
def random_noise(self, x, r1, r2):
199+
temp_x = x.clone()
200+
noise = (r1 - r2) * torch.rand(x.shape) + r2
201+
return torch.clamp(temp_x + noise.cuda(), min=0.0, max=1.0)
202+
203+
204+
def forward(self, x):
205+
# Encoder
206+
x = self.encoder(x)
207+
x = self.rate_less(x)
208+
209+
# noise
210+
if not self.training:
211+
noise = torch.rand_like(x, dtype=torch.float) * 0.02 - 0.01
212+
x = x + noise.clone()
213+
214+
x = self.decoder(x)
215+
self.rec_image = x.detach().clone()
216+
217+
return x
218+
219+
def configure_optimizers(self):
220+
optimizer = torch.optim.Adam(self.parameters(), lr=1e-4, betas=(0.9, 0.999))
221+
first_phase = int(self.trainer.max_steps * 0.95)
222+
# first_phase = 200_000
223+
224+
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=first_phase, gamma=0.1)
225+
226+
return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]
227+
228+
229+
def training_step(self, train_batch, batch_idx):
230+
images = train_batch
231+
232+
outputs = self(images)
233+
234+
if self.loss_func =='msssim':
235+
loss = 1 - self.ms_ssim(outputs, images)
236+
self.log('during_train_loss_ms_ssim', loss, on_epoch=True, prog_bar=True, logger=True)
237+
238+
if self.loss_func =='mse':
239+
loss = nn.MSELoss()(outputs, images)
240+
self.log('during_train_loss_mse', loss, on_epoch=True, prog_bar=True, logger=True)
241+
242+
self.training_step_loss.append(loss)
243+
244+
return {
245+
'loss': loss
246+
}
247+
248+
def on_train_epoch_end(self):
249+
250+
loss = torch.stack([x for x in self.training_step_loss]).mean()
251+
self.log('train_loss_epoch', loss, on_epoch=True, prog_bar=True, logger=True)
252+
253+
self.training_step_loss.clear()
254+
255+
lr = self.trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0]
256+
self.log('learning_rate', lr, on_step=False, on_epoch=True)
257+
258+
def validation_step(self, val_batch, batch_idx):
259+
260+
self.saved_images = []
261+
msssim_temp = []
262+
psnr_temp = []
263+
264+
for t in [2, 6, 12]:
265+
self.p = t/12
266+
images = val_batch
267+
outputs = self(images)
268+
269+
psnr = psnr_batch(outputs, images)
270+
psnr_temp.append(psnr)
271+
272+
ms_ssim = ms_ssim_batch(outputs, images)
273+
msssim_temp.append(ms_ssim)
274+
275+
self.saved_images.append(self.rec_image[0])
276+
277+
278+
self.validation_step_psnr.append(psnr_temp)
279+
self.validation_step_ms_ssim.append(msssim_temp)
280+
281+
self.p = None
282+
return {'loss': psnr, 'ms_ssim': ms_ssim}
283+
284+
def on_validation_epoch_end(self):
285+
psnr = torch.stack([x[0] for x in self.validation_step_psnr]).mean()
286+
self.log('val_psnr_2l_epoch', psnr, on_epoch=True, prog_bar=True, logger=True)
287+
288+
psnr = torch.stack([x[1] for x in self.validation_step_psnr]).mean()
289+
self.log('val_psnr_6l_epoch', psnr, on_epoch=True, prog_bar=True, logger=True)
290+
291+
psnr = torch.stack([x[2] for x in self.validation_step_psnr]).mean()
292+
self.log('val_psnr_12l_epoch', psnr, on_epoch=True, prog_bar=True, logger=True)
293+
294+
295+
ms_ssim = torch.stack([x[0] for x in self.validation_step_ms_ssim]).mean()
296+
self.log('val_ms_ssim_2l_epoch', ms_ssim, on_epoch=True, prog_bar=True, logger=True)
297+
298+
ms_ssim = torch.stack([x[1] for x in self.validation_step_ms_ssim]).mean()
299+
self.log('val_ms_ssim_6l_epoch', ms_ssim, on_epoch=True, prog_bar=True, logger=True)
300+
301+
ms_ssim = torch.stack([x[2] for x in self.validation_step_ms_ssim]).mean()
302+
self.log('val_ms_ssim_12l_epoch', ms_ssim, on_epoch=True, prog_bar=True, logger=True)
303+
304+
305+
self.logger.experiment.log({"rec_image_2l": wandb.Image(self.saved_images[0])})
306+
self.logger.experiment.log({"rec_image_6l": wandb.Image(self.saved_images[1])})
307+
self.logger.experiment.log({"rec_image_12l": wandb.Image(self.saved_images[2])})
308+
309+
self.p = None
310+
self.validation_step_psnr.clear()
311+
self.validation_step_ms_ssim.clear()
312+
313+
def rate_less(self,x):
314+
temp_x = x.clone()
315+
if self.p:
316+
# p shows the percentage of keeping
317+
p = self.p
318+
else:
319+
p = np.random.randint(1, 13)/12
320+
321+
if p != 1.0:
322+
p = int(p * x.shape[1])
323+
replace_tensor = torch.rand(x.shape[0], x.shape[1]-p, x.shape[2], x.shape[3]).fill_(self.replace_value)
324+
temp_x[:,-(x.shape[1]-p):,:,:] = replace_tensor
325+
return temp_x

mcucoder/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .data_preprocessing import *
2+
from .model import *
182 Bytes
Binary file not shown.
194 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.
10.2 KB
Binary file not shown.
17.8 KB
Binary file not shown.

mcucoder/data_preprocessing.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright 2024 Kiel University
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import lightning as L
7+
import torch
8+
from torchvision import datasets, transforms
9+
from torch.utils.data import Dataset
10+
import os
11+
from PIL import Image
12+
13+
class CustomImageDataset(Dataset):
14+
def __init__(self, root_dir, transform=None):
15+
self.root_dir = root_dir
16+
self.transform = transform
17+
self.image_files = [f for f in os.listdir(root_dir) if os.path.isfile(os.path.join(root_dir, f))]
18+
19+
def __len__(self):
20+
return len(self.image_files)
21+
22+
def __getitem__(self, idx):
23+
img_name = os.path.join(self.root_dir, self.image_files[idx])
24+
image = Image.open(img_name)
25+
if self.transform:
26+
image = self.transform(image)
27+
return image

0 commit comments

Comments
 (0)