Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel validator and migrations #60

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 57 additions & 25 deletions trainer/diffusers_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import re
import traceback
import shutil
from concurrent.futures import ThreadPoolExecutor

try:
pynvml.nvmlInit()
Expand Down Expand Up @@ -169,24 +170,23 @@ def __init__(self, is_skipped: bool, is_extended: bool) -> None:

def __validate(self, fp: str) -> bool:
try:
Image.open(fp)
Image.open(fp).close()
return True
except:
print(f'WARNING: Image cannot be opened: {fp}')
tqdm.tqdm.write(f'WARNING: Image cannot be opened: {fp}')
return False

def __extended_validate(self, fp: str) -> bool:
try:
Image.open(fp).load()
im = Image.open(fp)
im.load()
im.close()
return True
except (OSError) as error:
if 'truncated' in str(error):
print(f'WARNING: Image truncated: {error}')
return False
print(f'WARNING: Image cannot be opened: {error}')
tqdm.tqdm.write(f'WARNING: {error}: {fp}')
return False
except:
print(f'WARNING: Image cannot be opened: {error}')
tqdm.tqdm.write(f'WARNING: Image cannot be opened: {fp}')
return False

def __no_op(self, fp: str) -> bool:
Expand Down Expand Up @@ -221,19 +221,32 @@ def __no_migration(self, image_path: str, w: int, h: int) -> Img:

def __migration(self, image_path: str, w: int, h: int) -> Img:
filename = re.sub('\.[^/.]+$', '', os.path.split(image_path)[1])

image = ImageOps.fit(
Image.open(image_path),
(w, h),
bleed=0.0,
centering=(0.5, 0.5),
method=Image.Resampling.LANCZOS
).convert(mode='RGB')

image.save(
os.path.join(f'{self.__directory}', f'{filename}.jpg'),
optimize=True
)
image = Image.open(image_path)
needs_update = False
if image.size != (w, h):
image = ImageOps.fit(
image,
(w, h),
bleed=0.0,
centering=(0.5, 0.5),
method=Image.Resampling.LANCZOS
)
needs_update = True
if image.mode != 'RGB':
image = image.convert(mode='RGB')
needs_update = True
_,ext = os.path.splitext(image_path)
if needs_update:
image.save(
os.path.join(f'{self.__directory}', f'{filename}{ext}'),
optimize=True, quality=100 if ext == '.webp' else 95
)
else:
shutil.copy(
image_path,
os.path.join(f'{self.__directory}', f'{filename}{ext}'),
follow_symlinks=False
)

try:
shutil.copy(
Expand Down Expand Up @@ -268,7 +281,18 @@ def __init__(self, data_dir: str) -> None:

self.resizer = Resize(args.resize, args.no_migration).resize

self.image_files = [x for x in self.image_files if self.validator(x)]
pool = ThreadPoolExecutor()
futures = []
for x in self.image_files:
futures.append(pool.submit(self.validator, x))
self.image_files = [
x[1] for x in tqdm.tqdm(
zip(futures, self.image_files),
total=len(self.image_files),
desc='Validating', dynamic_ncols=True)
if x[0].result()
]
pool.shutdown(wait=True)

def __len__(self) -> int:
return len(self.image_files)
Expand Down Expand Up @@ -832,9 +856,17 @@ def main():

# Migrate dataset
if args.resize and not args.no_migration:
for _, batch in enumerate(train_dataloader):
continue
print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
pool = ThreadPoolExecutor()
pbar = tqdm.tqdm(total=len(dataset), desc='Migrating', dynamic_ncols=True)
def _callback(x):
x.result().close()
pbar.update(1)
for b in sampler:
for idx, w, h in b:
f = pool.submit(store.get_image, (idx, w, h))
f.add_done_callback(_callback)
pool.shutdown(wait = True)
tqdm.tqdm.write(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
exit(0)

# create ema
Expand Down