Skip to content

Commit 0acf7d3

Browse files
authoredApr 7, 2020
dataloader for MIT-Adobe fivek dataset
1 parent 11659a1 commit 0acf7d3

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed
 

‎dataloaders/dataset_rgb2raw.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import numpy as np
2+
import os
3+
from torch.utils.data import Dataset
4+
import torch
5+
6+
from utils.image_utils import is_numpy_file, load_npy, pack_raw, load_dict
7+
from utils.dataset_utils import Augment_Bayer, bayer_unify
8+
9+
augment = Augment_Bayer()
10+
transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')]
11+
12+
13+
class DataLoaderTrain(Dataset):
14+
def __init__(self, raw_dir, rgb_dir, img_options=None):
15+
super(DataLoaderTrain, self).__init__()
16+
17+
self.pkl_bayer_patterns = load_dict('./datasets/fivek_bayer.pkl')
18+
19+
rgb_files=sorted(os.listdir(rgb_dir))
20+
raw_files=sorted(os.listdir(raw_dir))
21+
22+
self.rgb_filenames = [os.path.join(rgb_dir, x) for x in rgb_files if is_numpy_file(x)]
23+
self.raw_filenames = [os.path.join(raw_dir, x) for x in raw_files if is_numpy_file(x)]
24+
25+
self.img_options=img_options
26+
self.rgb_size = len(self.rgb_filenames) # get the size of input
27+
self.raw_size = len(self.raw_filenames) # get the size of target
28+
29+
def __len__(self):
30+
return max(self.rgb_size, self.raw_size)
31+
32+
def __getitem__(self, index):
33+
rgb_index = index % self.rgb_size
34+
raw_index = index % self.raw_size
35+
36+
filename = os.path.splitext(os.path.split(self.rgb_filenames[rgb_index])[-1])[0]
37+
bayer_pattern = self.pkl_bayer_patterns[filename]
38+
39+
## Load Images
40+
rgb_image = load_npy(self.rgb_filenames[rgb_index])
41+
raw_image = load_npy(self.raw_filenames[raw_index])
42+
43+
44+
#Extract random crops from rgb and raw images
45+
ps = self.img_options['patch_size']
46+
ps_temp = ps*2 + 16
47+
H = raw_image.shape[0]
48+
W = raw_image.shape[1]
49+
r = np.random.randint(0, H - ps_temp)
50+
c = np.random.randint(0, W - ps_temp)
51+
if r%2!=0: r = r-1
52+
if c%2!=0: c = c-1
53+
rgb_patch = rgb_image[r:r + ps_temp, c:c + ps_temp, :]
54+
raw_patch = raw_image[r:r + ps_temp, c:c + ps_temp, :]
55+
56+
57+
raw_patch, rgb_patch = bayer_unify(raw_patch.squeeze(), rgb_patch, bayer_pattern, "RGGB", "crop")
58+
59+
#Apply Bayer Augmentation
60+
indx = np.random.randint(0,len(transforms_aug))
61+
apply_trans = transforms_aug[indx]
62+
63+
raw_patch, rgb_patch = getattr(augment, apply_trans)(raw_patch[...,np.newaxis], rgb_patch)
64+
65+
#Pack Target
66+
raw_patch = pack_raw(raw_patch)
67+
68+
# Extract crops of desired patch size
69+
H = raw_patch.shape[0]
70+
W = raw_patch.shape[1]
71+
r = (H - ps) // 2
72+
c = (W - ps) // 2
73+
PS, R, C = ps*2, r*2, c*2
74+
rgb_patch = rgb_patch[R:R + PS, C:C + PS, :]
75+
raw_patch = raw_patch[r:r + ps, c:c + ps, :]
76+
77+
rgb_patch = torch.Tensor(rgb_patch).permute(2,0,1)
78+
raw_patch = torch.Tensor(raw_patch).permute(2,0,1)
79+
80+
return rgb_patch,raw_patch
81+
82+

0 commit comments

Comments
 (0)