Commit
Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into dev2
bmaltais committed Jul 23, 2023
2 parents e3426eb + b1e44e9 commit 46638c3
Showing 8 changed files with 85 additions and 108 deletions.
2 changes: 1 addition & 1 deletion docs/train_README-ja.md
@@ -295,7 +295,7 @@ Stable Diffusion v1 is trained at 512\*512, but in addition to that…

Also, because training runs at arbitrary resolutions, there is no longer any need to unify the aspect ratios of the image data in advance.

You can toggle this between enabled and "over there" (向こう) in the settings; in the configuration file examples so far it is enabled (`true` is set).
You can toggle this between enabled and disabled (無効) in the settings; in the configuration file examples so far it is enabled (`true` is set).

Training resolutions are created by adjusting width and height in 64-pixel units (the default; configurable) so that their area (= memory usage) does not exceed the area of the resolution given as a parameter.

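To make the bucketing rule above concrete, here is a minimal illustrative sketch (not the script's actual implementation) that enumerates training resolutions in 64-pixel units under a 1024x1024 area budget; the function name and the size bounds are hypothetical:

def make_buckets(max_area=1024 * 1024, step=64, min_size=256, max_size=2048):
    # Enumerate (width, height) pairs in `step`-pixel units whose area stays
    # within the given budget, mirroring the rule described above.
    buckets = set()
    for w in range(min_size, max_size + 1, step):
        # largest multiple of `step` that keeps w*h within the area budget
        h = min((max_area // w) // step * step, max_size)
        if h >= min_size:
            buckets.add((w, h))
    return sorted(buckets)

print(make_buckets()[:3])  # [(256, 2048), (320, 2048), (384, 2048)]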
109 changes: 55 additions & 54 deletions docs/train_README-zh.md

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion library/sdxl_model_util.py
@@ -1,7 +1,7 @@
import torch
from safetensors.torch import load_file, save_file
from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import AutoencoderKL, EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from library import model_util
from library import sdxl_original_unet

@@ -486,6 +486,8 @@ def update_sd(prefix, sd):
def save_diffusers_checkpoint(
output_dir, text_encoder1, text_encoder2, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False, save_dtype=None
):
from diffusers import StableDiffusionXLPipeline

# convert U-Net
unet_sd = unet.state_dict()
du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
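The StableDiffusionXLPipeline import moves from module scope into save_diffusers_checkpoint, so importing library.sdxl_model_util no longer fails on diffusers versions that lack the SDXL pipeline. A minimal sketch of the same deferred-import pattern (function name hypothetical):

def save_checkpoint_sketch(output_dir):
    # Importing here rather than at module scope keeps this module importable
    # on older diffusers installs that do not ship the SDXL pipeline.
    from diffusers import StableDiffusionXLPipeline
    print(f"would build a StableDiffusionXLPipeline and save it to {output_dir}")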
23 changes: 13 additions & 10 deletions library/sdxl_train_util.py
@@ -2,13 +2,10 @@
import gc
import math
import os
from types import SimpleNamespace
from typing import Any
from typing import Optional
import torch
from tqdm import tqdm
from transformers import CLIPTokenizer
import open_clip
from diffusers import StableDiffusionXLPipeline
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline

@@ -18,7 +15,6 @@
DEFAULT_NOISE_OFFSET = 0.0357


# TODO: separate checkpoint for each U-Net/Text Encoder/VAE
def load_target_model(args, accelerator, model_version: str, weight_dtype):
# load models for each process
for pi in range(accelerator.state.num_processes):
@@ -33,7 +29,13 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
unet,
logit_scale,
ckpt_info,
) = _load_target_model(args, model_version, weight_dtype, accelerator.device if args.lowram else "cpu")
) = _load_target_model(
args.pretrained_model_name_or_path,
args.vae,
model_version,
weight_dtype,
accelerator.device if args.lowram else "cpu",
)

# work on low-ram device
if args.lowram:
@@ -51,8 +53,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info


def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
name_or_path = args.pretrained_model_name_or_path
def _load_target_model(name_or_path: str, vae_path: Optional[str], model_version: str, weight_dtype, device="cpu"):
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
load_stable_diffusion_format = os.path.isfile(name_or_path) # determine SD or Diffusers

@@ -68,6 +69,8 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
) = sdxl_model_util.load_models_from_sdxl_checkpoint(model_version, name_or_path, device)
else:
# Diffusers model is loaded to CPU
from diffusers import StableDiffusionXLPipeline

variant = "fp16" if weight_dtype == torch.float16 else None
print(f"load Diffusers pretrained models: {name_or_path}, variant={variant}")
try:
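The try block is truncated here; presumably it falls back to the default weight files when a repository publishes no fp16 variant. A sketch of that pattern under that assumption (function name hypothetical; the exact exception type may differ):

import torch
from diffusers import StableDiffusionXLPipeline

def load_pipeline_sketch(name_or_path, weight_dtype):
    variant = "fp16" if weight_dtype == torch.float16 else None
    try:
        # prefer the fp16 weight files when training in fp16
        return StableDiffusionXLPipeline.from_pretrained(
            name_or_path, variant=variant, torch_dtype=weight_dtype
        )
    except EnvironmentError:
        # assumed fallback: the repo publishes no fp16 variant, retry plain
        return StableDiffusionXLPipeline.from_pretrained(
            name_or_path, torch_dtype=weight_dtype
        )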
@@ -102,8 +105,8 @@ def _load_target_model(args: argparse.Namespace, model_version: str, weight_dtype, device="cpu"):
ckpt_info = None

# Load the VAE
if args.vae is not None:
vae = model_util.load_vae(args.vae, weight_dtype)
if vae_path is not None:
vae = model_util.load_vae(vae_path, weight_dtype)
print("additional VAE loaded")

return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info
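Since _load_target_model now takes the checkpoint path and optional VAE path directly instead of an argparse.Namespace, it can be reused by callers that have no args object (as sdxl_gen_img.py does below). A hypothetical call under the new signature:

import torch
from library import sdxl_model_util, sdxl_train_util

(
    load_sd_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info,
) = sdxl_train_util._load_target_model(
    "models/sd_xl_base_0.9.safetensors",  # hypothetical checkpoint path
    None,  # vae_path: no separate VAE
    sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9,
    torch.float16,
    device="cpu",
)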
7 changes: 3 additions & 4 deletions library/train_util.py
@@ -65,8 +65,8 @@
from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
import library.model_util as model_util
import library.huggingface_util as huggingface_util
from library.attention_processors import FlashAttnProcessor
from library.hypernetwork import replace_attentions_for_hypernetwork
# from library.attention_processors import FlashAttnProcessor
# from library.hypernetwork import replace_attentions_for_hypernetwork
from library.original_unet import UNet2DConditionModel

# Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
@@ -1884,8 +1884,7 @@ def load_latents_from_disk(
) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
npz = np.load(npz_path)
if "latents" not in npz:
print(f"error: npz is old format. please re-generate {npz_path}")
return None, None, None, None
raise ValueError(f"error: npz is old format. please re-generate {npz_path}")

latents = npz["latents"]
original_size = npz["original_size"].tolist()
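Because the old-format case now raises instead of returning four Nones, a caller that wants to regenerate stale caches can catch the error explicitly rather than None-checking every element. A sketch with a hypothetical caller (path and tuple variable names assumed):

from library import train_util

npz_path = "train_data/0001.npz"  # hypothetical cache file
try:
    latents, original_size, crop_info, flipped_latents = train_util.load_latents_from_disk(npz_path)
except ValueError:
    # stale cache written by an older version: regenerate rather than skip silently
    print(f"re-caching latents for {npz_path}")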
2 changes: 1 addition & 1 deletion networks/sdxl_merge_lora.py
@@ -204,7 +204,7 @@ def str_to_dtype(p):
ckpt_info,
) = sdxl_model_util.load_models_from_sdxl_checkpoint(sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.sd_model, "cpu")

merge_to_sd_model(text_model2, text_model2, unet, args.models, args.ratios, merge_dtype)
merge_to_sd_model(text_model1, text_model2, unet, args.models, args.ratios, merge_dtype)

print(f"saving SD model to: {args.save_to}")
sdxl_model_util.save_stable_diffusion_checkpoint(
38 changes: 4 additions & 34 deletions sdxl_gen_img.py
@@ -958,7 +958,7 @@ def get_unweighted_text_embeddings(
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
text_embedding = enc_out["hidden_states"][-2]
if pool is None:
pool = enc_out["text_embeds"] # use 1st chunk
pool = enc_out.get("text_embeds", None) # use 1st chunk, if provided

if no_boseos_middle:
if i == 0:
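In SDXL, only the second text encoder (a CLIPTextModelWithProjection) returns pooled "text_embeds"; the first (a plain CLIPTextModel) does not, so the pooled output is now fetched defensively. A standalone sketch of the guard, assuming network access to the public openai/clip-vit-large-patch14 checkpoint:

from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
tokens = tokenizer("a photo of a cat", return_tensors="pt")
enc_out = text_encoder(tokens.input_ids, output_hidden_states=True, return_dict=True)
# CLIPTextModel has no projection head, so "text_embeds" is absent and .get()
# returns None; a CLIPTextModelWithProjection would return a pooled tensor here.
pool = enc_out.get("text_embeds", None)
print(pool)  # None for the first SDXL text encoder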
@@ -1288,38 +1288,9 @@ def main(args):
if len(files) == 1:
args.ckpt = files[0]

use_stable_diffusion_format = os.path.isfile(args.ckpt)
assert use_stable_diffusion_format, "Diffusers pretrained models are not supported yet"
print("load StableDiffusion checkpoint")
text_encoder1, text_encoder2, vae, unet, _, _ = sdxl_model_util.load_models_from_sdxl_checkpoint(
sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, args.ckpt, "cpu"
(_, text_encoder1, text_encoder2, vae, unet, _, _) = sdxl_train_util._load_target_model(
args.ckpt, args.vae, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V0_9, dtype
)
# else:
# print("load Diffusers pretrained models")
# TODO use Diffusers 0.18.1 and support SDXL pipeline
# raise NotImplementedError("Diffusers pretrained models are not supported yet")
# loading_pipe = StableDiffusionXLPipeline.from_pretrained(args.ckpt, safety_checker=None, torch_dtype=dtype)
# text_encoder = loading_pipe.text_encoder
# vae = loading_pipe.vae
# unet = loading_pipe.unet
# tokenizer = loading_pipe.tokenizer
# del loading_pipe

# # Diffusers U-Net to original U-Net
# original_unet = SdxlUNet2DConditionModel(
# unet.config.sample_size,
# unet.config.attention_head_dim,
# unet.config.cross_attention_dim,
# unet.config.use_linear_projection,
# unet.config.upcast_attention,
# )
# original_unet.load_state_dict(unet.state_dict())
# unet = original_unet

# Load the VAE
if args.vae is not None:
vae = model_util.load_vae(args.vae, dtype)
print("additional VAE loaded")

# xformers / Hypernetwork support
if not args.diffusers_xformers:
@@ -1329,8 +1300,7 @@

# load the tokenizer
print("loading tokenizer")
if use_stable_diffusion_format:
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)
tokenizer1, tokenizer2 = sdxl_train_util.load_tokenizers(args)

# prepare the scheduler
sched_init_args = {}
8 changes: 5 additions & 3 deletions sdxl_train.py
@@ -333,15 +333,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
args.save_every_n_epochs = math.floor(num_train_epochs / args.save_n_epoch_ratio) or 1

# start training
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
accelerator.print("running training / 学習開始")
accelerator.print(f" num examples / サンプル数: {train_dataset_group.num_train_images}")
accelerator.print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
accelerator.print(f" num epochs / epoch数: {num_train_epochs}")
accelerator.print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
accelerator.print(
f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
f" batch size per device / バッチサイズ: {', '.join([str(d.batch_size) for d in train_dataset_group.datasets])}"
)
# accelerator.print(
# f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}"
# )
accelerator.print(f" gradient accumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
accelerator.print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")

