Description
In train_lcm_distill_sd_wds.py:
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)
In train_lcm_distill_sdxl_wds.py, the w_embedding step is missing:
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)
Is there any reason for this? The code for the XL model does not work without it; the script defines:
noise_pred = unet(
    noisy_model_input,
    start_timesteps,
    timestep_cond=None,
    encoder_hidden_states=prompt_embeds.float(),
    added_cond_kwargs=encoded_text,
).sample
Here timestep_cond is None, although unet_time_cond_proj_dim is still required, as raised in the other issue.
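
For reference, here is a rough sketch of the step the SD script performs and the SDXL script skips. It assumes guidance_scale_embedding is a standard sinusoidal embedding (its body is not shown above, so the 1000x scaling and the frequency range are assumptions). It is written in plain Python so it runs without PyTorch. The values of bsz, w_min, w_max and proj_dim are hypothetical stand-ins for the batch size and for args.w_min, args.w_max and args.unet_time_cond_proj_dim.

```python
import math
import random


def guidance_scale_embedding(w, embedding_dim=512):
    """Plain-Python sketch of a sinusoidal embedding for guidance scales.

    Assumption: mirrors the usual sin/cos timestep-style embedding; the
    real helper in the training script operates on tensors.
    Requires embedding_dim >= 4 so that half_dim - 1 is non-zero.
    """
    half_dim = embedding_dim // 2
    # Frequencies decay geometrically from 1 down to 1/10000.
    step = math.log(10000.0) / (half_dim - 1)
    freqs = [math.exp(-step * i) for i in range(half_dim)]
    rows = []
    for scale in w:
        scaled = scale * 1000.0  # assumed scaling factor
        row = [math.sin(scaled * f) for f in freqs]
        row += [math.cos(scaled * f) for f in freqs]
        if embedding_dim % 2 == 1:
            row.append(0.0)  # zero-pad odd embedding sizes
        rows.append(row)
    return rows


# Hypothetical values standing in for bsz and the args.* settings.
bsz, w_min, w_max, proj_dim = 4, 5.0, 15.0, 256

# Sample w from U[w_min, w_max], one scale per sample in the batch.
w = [(w_max - w_min) * random.random() + w_min for _ in range(bsz)]

# One embedding row per sample, each of length proj_dim.
w_embedding = guidance_scale_embedding(w, embedding_dim=proj_dim)
```

In the SD script the tensor equivalent of w_embedding is presumably what gets passed to the U-Net as timestep_cond, whereas the SDXL call above passes None.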