|
| 1 | +import numpy as np |
| 2 | +import os |
| 3 | +from torch.utils.data import Dataset |
| 4 | +import torch |
| 5 | + |
| 6 | +from utils.image_utils import is_numpy_file, load_npy, pack_raw, load_dict |
| 7 | +from utils.dataset_utils import Augment_Bayer, bayer_unify |
| 8 | + |
| 9 | +augment = Augment_Bayer() |
| 10 | +transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')] |
| 11 | + |
| 12 | + |
| 13 | +class DataLoaderTrain(Dataset): |
| 14 | + def __init__(self, raw_dir, rgb_dir, img_options=None): |
| 15 | + super(DataLoaderTrain, self).__init__() |
| 16 | + |
| 17 | + self.pkl_bayer_patterns = load_dict('./datasets/fivek_bayer.pkl') |
| 18 | + |
| 19 | + rgb_files=sorted(os.listdir(rgb_dir)) |
| 20 | + raw_files=sorted(os.listdir(raw_dir)) |
| 21 | + |
| 22 | + self.rgb_filenames = [os.path.join(rgb_dir, x) for x in rgb_files if is_numpy_file(x)] |
| 23 | + self.raw_filenames = [os.path.join(raw_dir, x) for x in raw_files if is_numpy_file(x)] |
| 24 | + |
| 25 | + self.img_options=img_options |
| 26 | + self.rgb_size = len(self.rgb_filenames) # get the size of input |
| 27 | + self.raw_size = len(self.raw_filenames) # get the size of target |
| 28 | + |
| 29 | + def __len__(self): |
| 30 | + return max(self.rgb_size, self.raw_size) |
| 31 | + |
| 32 | + def __getitem__(self, index): |
| 33 | + rgb_index = index % self.rgb_size |
| 34 | + raw_index = index % self.raw_size |
| 35 | + |
| 36 | + filename = os.path.splitext(os.path.split(self.rgb_filenames[rgb_index])[-1])[0] |
| 37 | + bayer_pattern = self.pkl_bayer_patterns[filename] |
| 38 | + |
| 39 | + ## Load Images |
| 40 | + rgb_image = load_npy(self.rgb_filenames[rgb_index]) |
| 41 | + raw_image = load_npy(self.raw_filenames[raw_index]) |
| 42 | + |
| 43 | + |
| 44 | + #Extract random crops from rgb and raw images |
| 45 | + ps = self.img_options['patch_size'] |
| 46 | + ps_temp = ps*2 + 16 |
| 47 | + H = raw_image.shape[0] |
| 48 | + W = raw_image.shape[1] |
| 49 | + r = np.random.randint(0, H - ps_temp) |
| 50 | + c = np.random.randint(0, W - ps_temp) |
| 51 | + if r%2!=0: r = r-1 |
| 52 | + if c%2!=0: c = c-1 |
| 53 | + rgb_patch = rgb_image[r:r + ps_temp, c:c + ps_temp, :] |
| 54 | + raw_patch = raw_image[r:r + ps_temp, c:c + ps_temp, :] |
| 55 | + |
| 56 | + |
| 57 | + raw_patch, rgb_patch = bayer_unify(raw_patch.squeeze(), rgb_patch, bayer_pattern, "RGGB", "crop") |
| 58 | + |
| 59 | + #Apply Bayer Augmentation |
| 60 | + indx = np.random.randint(0,len(transforms_aug)) |
| 61 | + apply_trans = transforms_aug[indx] |
| 62 | + |
| 63 | + raw_patch, rgb_patch = getattr(augment, apply_trans)(raw_patch[...,np.newaxis], rgb_patch) |
| 64 | + |
| 65 | + #Pack Target |
| 66 | + raw_patch = pack_raw(raw_patch) |
| 67 | + |
| 68 | + # Extract crops of desired patch size |
| 69 | + H = raw_patch.shape[0] |
| 70 | + W = raw_patch.shape[1] |
| 71 | + r = (H - ps) // 2 |
| 72 | + c = (W - ps) // 2 |
| 73 | + PS, R, C = ps*2, r*2, c*2 |
| 74 | + rgb_patch = rgb_patch[R:R + PS, C:C + PS, :] |
| 75 | + raw_patch = raw_patch[r:r + ps, c:c + ps, :] |
| 76 | + |
| 77 | + rgb_patch = torch.Tensor(rgb_patch).permute(2,0,1) |
| 78 | + raw_patch = torch.Tensor(raw_patch).permute(2,0,1) |
| 79 | + |
| 80 | + return rgb_patch,raw_patch |
| 81 | + |
| 82 | + |
0 commit comments