From 3965e4d0a4d8a82f74a9caf933bc9aa04b6d4af5 Mon Sep 17 00:00:00 2001
From: Tom White
Date: Sun, 6 Nov 2016 17:04:11 +1300
Subject: [PATCH] Added -seeds option to provide seeds from disk

---
 enhance.py | 84 ++++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 72 insertions(+), 12 deletions(-)

diff --git a/enhance.py b/enhance.py
index bd1e587..ac54d36 100755
--- a/enhance.py
+++ b/enhance.py
@@ -41,6 +41,7 @@
 add_arg('--rendering-overlap', default=32, type=int, help='Number of pixels padding around each tile.')
 add_arg('--model', default='small', type=str, help='Name of the neural network to load/save.')
 add_arg('--train', default=False, type=str, help='File pattern to load for training.')
+add_arg('--seeds', default=False, type=str, help='File pattern to load for training seeds.')
 add_arg('--train-blur', default=None, type=int, help='Sigma value for gaussian blur preprocess.')
 add_arg('--train-noise', default=None, type=float, help='Radius for preprocessing gaussian blur.')
 add_arg('--train-jpeg', default=None, type=int, help='JPEG compression level in preprocessing.')
@@ -125,6 +126,28 @@ def extend(lst): return itertools.chain(lst, itertools.repeat(lst[-1]))
 
 print('{} - Using the device `{}` for neural computation.{}\n'.format(ansi.CYAN, theano.config.device, ansi.ENDC))
 
+def confirm_pairs(list1, list2):
+    new_list1 = []
+    new_list2 = []
+    cur1 = 0
+    cur2 = 0
+    len1 = len(list1)
+    len2 = len(list2)
+    while(cur1 < len1 and cur2 < len2):
+        base1 = os.path.basename(list1[cur1])
+        base2 = os.path.basename(list2[cur2])
+        if base1 == base2:
+            new_list1.append(list1[cur1])
+            new_list2.append(list2[cur2])
+            cur1 = cur1 + 1
+            cur2 = cur2 + 1
+        elif base1 < base2:
+            # continue to look on list1, don't iterate list2
+            cur1 = cur1 + 1
+        else:
+            cur2 = cur2 + 1
+    print("List sizes went from {}, {} to {}, {}".format(len1, len2, len(new_list1), len(new_list2)))
+    return new_list1, new_list2
 
 #======================================================================================================================
 # Image Processing
@@ -140,7 +163,13 @@ def __init__(self):
         self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32)
         self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32)
 
-        self.files = glob.glob(args.train)
+        self.files = sorted(glob.glob(args.train))
+        if args.seeds:
+            self.seeds = sorted(glob.glob(args.seeds))
+            self.files, self.seeds = confirm_pairs(self.files, self.seeds)
+        else:
+            self.seeds = False
+
         if len(self.files) == 0:
             error("There were no files found to train from searching for `{}`".format(args.train),
                   " - Try putting all your images in one folder and using `--train=data/*.jpg`")
@@ -153,11 +182,14 @@ def __init__(self):
 
     def run(self):
         while True:
-            random.shuffle(self.files)
-            for f in self.files:
-                self.add_to_buffer(f)
+            indices = list(range(0, len(self.files)))
+            random.shuffle(indices)
+
+            for file_index in indices:
+                self.add_to_buffer(file_index)
 
-    def add_to_buffer(self, f):
+    def add_to_buffer(self, file_index):
+        f = self.files[file_index]
         filename = os.path.join(self.cwd, f)
         try:
             orig = PIL.Image.open(filename).convert('RGB')
@@ -168,27 +200,55 @@ def add_to_buffer(self, f):
         except Exception as e:
             warn('Could not load `{}` as image.'.format(filename),
                  ' - Try fixing or removing the file before next run.')
-            self.files.remove(f)
+            del self.files[file_index]
+            if self.seeds:
+                del self.seeds[file_index]
             return
 
-        seed = orig.filter(PIL.ImageFilter.GaussianBlur(radius=args.train_blur)) if args.train_blur else orig
-        seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=PIL.Image.LANCZOS)
+        # determine seed
+        if self.seeds:
+            # file based seed
+            f = self.seeds[file_index]
+            filename = os.path.join(self.cwd, f)
+            try:
+                seed = PIL.Image.open(filename).convert('RGB')
+                if any(s < self.seed_shape for s in seed.size):
+                    raise ValueError('Image is too small for seed size (found {}, expected {})'.format(seed.size, self.seed_shape))
+            except Exception as e:
+                warn('Could not load `{}` as seed image.'.format(filename),
+                     ' - Try fixing or removing the file before next run. ({})'.format(e))
+                del self.files[file_index]
+                del self.seeds[file_index]
+                return
+        else:
+            # synthetic seed
+            # optionally blur before scaling
+            seed = orig.filter(PIL.ImageFilter.GaussianBlur(radius=args.train_blur)) if args.train_blur else orig
+            # seed is scaled down version of original
+            seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=PIL.Image.LANCZOS)
+
+        # convert original to numpy float32
+        orig = scipy.misc.fromimage(orig).astype(np.float32)
 
-        if args.train_jpeg:
+        # optionally convert to jpeg
+        if args.train_jpeg is not None:
             buffer = io.BytesIO()
             seed.save(buffer, format='jpeg', quality=args.train_jpeg+random.randrange(-15,+15))
             seed = PIL.Image.open(buffer)
 
+        # convert seed to numpy float32
         seed = scipy.misc.fromimage(seed, mode='RGB').astype(np.float32)
-        seed += scipy.random.normal(scale=args.train_noise, size=(seed.shape[0], seed.shape[1], 1))\
-                    if args.train_noise else 0.0
-        orig = scipy.misc.fromimage(orig).astype(np.float32)
+        # optionally add training noise to seed
+        if args.train_noise is not None:
+            seed += scipy.random.normal(scale=args.train_noise, size=(seed.shape[0], seed.shape[1], 1))
 
         for _ in range(args.buffer_similar):
+            # compute seed displacement
            h = random.randint(0, seed.shape[0] - self.seed_shape)
             w = random.randint(0, seed.shape[1] - self.seed_shape)
             seed_chunk = seed[h:h+self.seed_shape, w:w+self.seed_shape]
+            # and matching image displacement
             h, w = h * args.zoom, w * args.zoom
             orig_chunk = orig[h:h+self.orig_shape, w:w+self.orig_shape]
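Usage note (illustration only; the file names below are hypothetical, not from the patch): the new option is meant to be combined with --train, e.g. --train="data/orig/*.png" --seeds="data/seeds/*.png", where each seed image shares its basename with an original and, since the chunk offsets above are multiplied by args.zoom, appears to be expected at 1/zoom the size of that original. confirm_pairs keeps only the files whose basenames occur in both sorted lists:

    # hypothetical example of confirm_pairs as defined in this patch
    files, seeds = confirm_pairs(['orig/a.png', 'orig/b.png', 'orig/c.png'],
                                 ['seed/a.png', 'seed/c.png'])
    # prints: List sizes went from 3, 2 to 2, 2
    # files == ['orig/a.png', 'orig/c.png']
    # seeds == ['seed/a.png', 'seed/c.png']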