
`w_embedding = guidance_scale_embedding(...)` is missing in `train_lcm_distill_sdxl_wds.py` #81

Open
PetrByvsh opened this issue Dec 29, 2023 · 2 comments

PetrByvsh commented Dec 29, 2023

In `train_lcm_distill_sd_wds.py`:

```python
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
```

In `train_lcm_distill_sdxl_wds.py`:

```python
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)
```

Is there a reason for this? The XL script does not work without it, since it calls the student U-Net with `timestep_cond=None`:

```python
noise_pred = unet(
    noisy_model_input,
    start_timesteps,
    timestep_cond=None,
    encoder_hidden_states=prompt_embeds.float(),
    added_cond_kwargs=encoded_text,
).sample
```

So `timestep_cond` is `None`, even though `unet_time_cond_proj_dim` is still required, as raised in the other issue.
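For context, `guidance_scale_embedding` in the SD 1.5 script produces a sinusoidal Fourier embedding of the guidance scale, following the standard timestep-embedding recipe. Below is a minimal sketch of how the missing lines could be restored in the SDXL script; the helper mirrors the one in `train_lcm_distill_sd_wds.py`, but treat this as an illustration under that assumption, not the exact upstream fix:

```python
import torch

def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
    """Sinusoidal embedding of guidance scale w -> (len(w), embedding_dim).

    Standard Fourier timestep-embedding recipe, as used in the SD script.
    """
    w = w * 1000.0
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero-pad odd embedding dimensions
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb

# Sample w ~ U[w_min, w_max] and embed it, as the SD script does
bsz, w_min, w_max, proj_dim = 4, 3.0, 15.0, 256
w = (w_max - w_min) * torch.rand((bsz,)) + w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# ...then pass timestep_cond=w_embedding into the student unet(...) call
# instead of timestep_cond=None.
```

The embedding has shape `(bsz, proj_dim)`, matching the `time_cond_proj_dim` the U-Net expects for `timestep_cond`.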

PetrByvsh reopened this Dec 29, 2023
shuminghu commented:

If you use the latest code from diffusers, the error is fixed by this one-line change:

```diff
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index ee86def6..a49e5f26 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -948,7 +948,7 @@ def main(args):
     # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
     if teacher_unet.config.time_cond_proj_dim is None:
         teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
-    time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
+    time_cond_proj_dim = teacher_unet.config["time_cond_proj_dim"]
     unet = UNet2DConditionModel(**teacher_unet.config)
     # load teacher_unet weights into unet
     unet.load_state_dict(teacher_unet.state_dict(), strict=False)
```
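The reason this one-line change matters is presumably that the config object behaves like a dict whose keys are also mirrored as attributes at construction time: a later `__setitem__` write updates the dict entry but leaves the attribute stale. A minimal, hypothetical stand-in (`ConfigLike` is invented for illustration, not the actual diffusers class) showing the divergence:

```python
class ConfigLike(dict):
    """Hypothetical stand-in for a dict-like config whose keys are
    mirrored as attributes once, at construction time only."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        for key, value in kwargs.items():
            # Attributes are snapshotted here and never refreshed
            object.__setattr__(self, key, value)

cfg = ConfigLike(time_cond_proj_dim=None)
cfg["time_cond_proj_dim"] = 256   # updates the dict entry only

print(cfg.time_cond_proj_dim)     # attribute access: stale value (None)
print(cfg["time_cond_proj_dim"])  # item access: updated value (256)
```

Since the script writes the new value via `teacher_unet.config["time_cond_proj_dim"] = ...`, reading it back must also go through item access, which is exactly what the diff changes.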

Neville0302 commented:

How did you solve this problem? (translated from Chinese)
