Skip to content
This repository has been archived by the owner on Jan 2, 2021. It is now read-only.

Commit

Permalink
Added -seeds option to provide seeds from disk
Browse files Browse the repository at this point in the history
  • Loading branch information
dribnet committed Nov 6, 2016
1 parent 149ae4e commit 3965e4d
Showing 1 changed file with 72 additions and 12 deletions.
84 changes: 72 additions & 12 deletions enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
add_arg('--rendering-overlap', default=32, type=int, help='Number of pixels padding around each tile.')
add_arg('--model', default='small', type=str, help='Name of the neural network to load/save.')
add_arg('--train', default=False, type=str, help='File pattern to load for training.')
add_arg('--seeds', default=False, type=str, help='File pattern to load for training seeds.')
add_arg('--train-blur', default=None, type=int, help='Sigma value for gaussian blur preprocess.')
add_arg('--train-noise', default=None, type=float, help='Radius for preprocessing gaussian blur.')
add_arg('--train-jpeg', default=None, type=int, help='JPEG compression level in preprocessing.')
Expand Down Expand Up @@ -125,6 +126,28 @@ def extend(lst): return itertools.chain(lst, itertools.repeat(lst[-1]))

print('{} - Using the device `{}` for neural computation.{}\n'.format(ansi.CYAN, theano.config.device, ansi.ENDC))

def confirm_pairs(list1, list2):
new_list1 = []
new_list2 = []
cur1 = 0
cur2 = 0
len1 = len(list1)
len2 = len(list2)
while(cur1 < len1 and cur2 < len2):
base1 = os.path.basename(list1[cur1])
base2 = os.path.basename(list2[cur2])
if base1 == base2:
new_list1.append(list1[cur1])
new_list2.append(list2[cur2])
cur1 = cur1 + 1
cur2 = cur2 + 1
elif base1 < base2:
# continue to look on list1, don't iterate list2
cur1 = cur1 + 1
else:
cur2 = cur2 + 1
print("List sizes went from {}, {} to {}, {}".format(len1, len2, len(new_list1), len(new_list2)))
return new_list1, new_list2

#======================================================================================================================
# Image Processing
Expand All @@ -140,7 +163,13 @@ def __init__(self):

self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32)
self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32)
self.files = glob.glob(args.train)
self.files = sorted(glob.glob(args.train))
if args.seeds:
self.seeds = sorted(glob.glob(args.seeds))
self.files, self.seeds = confirm_pairs(self.files, self.seeds)
else:
self.seeds = False

if len(self.files) == 0:
error("There were no files found to train from searching for `{}`".format(args.train),
" - Try putting all your images in one folder and using `--train=data/*.jpg`")
Expand All @@ -153,11 +182,14 @@ def __init__(self):

def run(self):
while True:
random.shuffle(self.files)
for f in self.files:
self.add_to_buffer(f)
indices = list(range(0, len(self.files)))
random.shuffle(indices)

for file_index in indices:
self.add_to_buffer(file_index)

def add_to_buffer(self, f):
def add_to_buffer(self, file_index):
f = self.files[file_index]
filename = os.path.join(self.cwd, f)
try:
orig = PIL.Image.open(filename).convert('RGB')
Expand All @@ -168,27 +200,55 @@ def add_to_buffer(self, f):
except Exception as e:
warn('Could not load `{}` as image.'.format(filename),
' - Try fixing or removing the file before next run.')
self.files.remove(f)
del self.files[file_index]
if self.seeds:
del self.seeds[file_index]
return

seed = orig.filter(PIL.ImageFilter.GaussianBlur(radius=args.train_blur)) if args.train_blur else orig
seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=PIL.Image.LANCZOS)
# determine seed
if self.seeds:
# file based seed
f = self.seeds[file_index]
filename = os.path.join(self.cwd, f)
try:
seed = PIL.Image.open(filename).convert('RGB')
if any(s < self.seed_shape for s in seed.size):
raise ValueError('Image is too small for seed size (found {}, expected {})'.format(seed.size, self.seed_shape))
except Exception as e:
warn('Could not load `{}` as seed image.'.format(filename),
' - Try fixing or removing the file before next run. ({})'.format(e))
del self.files[file_index]
del self.seeds[file_index]
return
else:
# synthetic seed
# optionally blur before scaling
seed = orig.filter(PIL.ImageFilter.GaussianBlur(radius=args.train_blur)) if args.train_blur else orig
# seed is scaled down version of original
seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=PIL.Image.LANCZOS)

# convert original to numpy float32
orig = scipy.misc.fromimage(orig).astype(np.float32)

if args.train_jpeg:
# optionally convert to jpeg
if args.train_jpeg is not None:
buffer = io.BytesIO()
seed.save(buffer, format='jpeg', quality=args.train_jpeg+random.randrange(-15,+15))
seed = PIL.Image.open(buffer)

# convert seed to numpy float32
seed = scipy.misc.fromimage(seed, mode='RGB').astype(np.float32)
seed += scipy.random.normal(scale=args.train_noise, size=(seed.shape[0], seed.shape[1], 1))\
if args.train_noise else 0.0

orig = scipy.misc.fromimage(orig).astype(np.float32)
# optionally add training noise to seed
if args.train_noise is not None:
seed += scipy.random.normal(scale=args.train_noise, size=(seed.shape[0], seed.shape[1], 1))

for _ in range(args.buffer_similar):
# compute seed displacement
h = random.randint(0, seed.shape[0] - self.seed_shape)
w = random.randint(0, seed.shape[1] - self.seed_shape)
seed_chunk = seed[h:h+self.seed_shape, w:w+self.seed_shape]
# and matching image displacement
h, w = h * args.zoom, w * args.zoom
orig_chunk = orig[h:h+self.orig_shape, w:w+self.orig_shape]

Expand Down

0 comments on commit 3965e4d

Please sign in to comment.