Skip to content
This repository has been archived by the owner on Jan 2, 2021. It is now read-only.

Seeds option #33

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 66 additions & 11 deletions enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
add_arg('--train-blur', default=None, type=int, help='Sigma value for gaussian blur preprocess.')
add_arg('--train-noise', default=None, type=float, help='Radius for preprocessing gaussian blur.')
add_arg('--train-jpeg', default=[], nargs='+', type=int, help='JPEG compression level & range in preproc.')
add_arg('--seeds', default=False, type=str, help='File pattern to load for training seeds.')
add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.')
Expand Down Expand Up @@ -128,6 +129,28 @@ def extend(lst): return itertools.chain(lst, itertools.repeat(lst[-1]))

print('{} - Using the device `{}` for neural computation.{}\n'.format(ansi.CYAN, theano.config.device, ansi.ENDC))

def confirm_pairs(list1, list2):
new_list1 = []
new_list2 = []
cur1 = 0
cur2 = 0
len1 = len(list1)
len2 = len(list2)
while(cur1 < len1 and cur2 < len2):
base1 = os.path.basename(list1[cur1])
base2 = os.path.basename(list2[cur2])
if base1 == base2:
new_list1.append(list1[cur1])
new_list2.append(list2[cur2])
cur1 = cur1 + 1
cur2 = cur2 + 1
elif base1 < base2:
# continue to look on list1, don't iterate list2
cur1 = cur1 + 1
else:
cur2 = cur2 + 1
print("List sizes went from {}, {} to {}, {}".format(len1, len2, len(new_list1), len(new_list2)))
return new_list1, new_list2

#======================================================================================================================
# Image Processing
Expand All @@ -143,11 +166,17 @@ def __init__(self):

self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32)
self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32)
self.files = glob.glob(args.train)
self.files = sorted(glob.glob(args.train))
if len(self.files) == 0:
error("There were no files found to train from searching for `{}`".format(args.train),
" - Try putting all your images in one folder and using `--train=data/*.jpg`")

if args.seeds:
self.seeds = sorted(glob.glob(args.seeds))
self.files, self.seeds = confirm_pairs(self.files, self.seeds)
else:
self.seeds = False

self.available = set(range(args.buffer_size))
self.ready = set()

Expand All @@ -156,11 +185,14 @@ def __init__(self):

def run(self):
while True:
random.shuffle(self.files)
for f in self.files:
self.add_to_buffer(f)
indices = list(range(0, len(self.files)))
random.shuffle(indices)

for file_index in indices:
self.add_to_buffer(file_index)

def add_to_buffer(self, f):
def add_to_buffer(self, file_index):
f = self.files[file_index]
filename = os.path.join(self.cwd, f)
try:
orig = PIL.Image.open(filename).convert('RGB')
Expand All @@ -172,14 +204,37 @@ def add_to_buffer(self, f):
except Exception as e:
warn('Could not load `{}` as image.'.format(filename),
' - Try fixing or removing the file before next run.')
self.files.remove(f)
del self.files[file_index]
if self.seeds:
del self.seeds[file_index]
return

seed = orig
if args.train_blur is not None:
seed = seed.filter(PIL.ImageFilter.GaussianBlur(radius=random.randint(0, args.train_blur*2)))
if args.zoom > 1:
seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=PIL.Image.LANCZOS)
# determine seed
if self.seeds:
# file based seed
f = self.seeds[file_index]
filename = os.path.join(self.cwd, f)
try:
seed = PIL.Image.open(filename).convert('RGB')
if any(s < self.seed_shape for s in seed.size):
raise ValueError('Image is too small for seed size (found {}, expected {})'.format(seed.size, self.seed_shape))
except Exception as e:
warn('Could not load `{}` as seed image.'.format(filename),
' - Try fixing or removing the file before next run. ({})'.format(e))
del self.files[file_index]
del self.seeds[file_index]
return

else:
# synthetic seed
seed = orig
# optionally blur before scaling
if args.train_blur is not None:
seed = seed.filter(PIL.ImageFilter.GaussianBlur(radius=random.randint(0, args.train_blur*2)))
# seed is scaled down version of original
if args.zoom > 1:
seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=PIL.Image.LANCZOS)

if len(args.train_jpeg) > 0:
buffer, rng = io.BytesIO(), args.train_jpeg[-1] if len(args.train_jpeg) > 1 else 15
seed.save(buffer, format='jpeg', quality=args.train_jpeg+random.randrange(-rng, +rng))
Expand Down