
w_embedding = guidance_scale_embedding is missing in the train_lcm_distill_sdxl_wds.py #81

Open
@PetrByvsh


In train_lcm_distill_sd_wds.py:
# 20.4.6. Sample a random guidance scale w from U[w_min, w_max] and embed it
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w_embedding = guidance_scale_embedding(w, embedding_dim=args.unet_time_cond_proj_dim)
w = w.reshape(bsz, 1, 1, 1)
# Move to U-Net device and dtype
w = w.to(device=latents.device, dtype=latents.dtype)
w_embedding = w_embedding.to(device=latents.device, dtype=latents.dtype)

In train_lcm_distill_sdxl_wds.py:
w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
w = w.reshape(bsz, 1, 1, 1)
w = w.to(device=latents.device, dtype=latents.dtype)

Is there any reason for this? The code for the XL model does not work without it. It defines:
noise_pred = unet(
    noisy_model_input,
    start_timesteps,
    timestep_cond=None,
    encoder_hidden_states=prompt_embeds.float(),
    added_cond_kwargs=encoded_text,
).sample
Here timestep_cond is None, although unet_time_cond_proj_dim is still required, as raised in the other issue.
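
For readers comparing the two scripts, here is a minimal, dependency-free sketch of what the missing step computes: sample w from U[w_min, w_max] and turn it into a fixed-size embedding. It assumes guidance_scale_embedding is a standard sinusoidal embedding (the 1000x scaling and the 10000 frequency base are assumptions, not confirmed by the snippet above), and sample_guidance is a hypothetical helper name used only for illustration. The real scripts operate on torch tensors; plain Python lists are used here so the example runs anywhere.

```python
import math
import random


def guidance_scale_embedding(w, embedding_dim=512):
    """Sinusoidal embedding of guidance scales (illustrative, pure Python).

    w: list of floats, one guidance scale per sample.
    Returns len(w) rows, each a list of embedding_dim floats.
    """
    if embedding_dim < 4:
        raise ValueError("embedding_dim must be at least 4")
    half_dim = embedding_dim // 2
    # Assumed convention: frequencies decay geometrically from 1 to 1/10000.
    step = math.log(10000.0) / (half_dim - 1)
    freqs = [math.exp(-step * i) for i in range(half_dim)]
    rows = []
    for scale in w:
        scaled = scale * 1000.0  # assumed scaling factor
        row = [math.sin(scaled * f) for f in freqs]
        row += [math.cos(scaled * f) for f in freqs]
        if embedding_dim % 2 == 1:
            row.append(0.0)  # zero-pad odd embedding sizes
        rows.append(row)
    return rows


def sample_guidance(bsz, w_min, w_max, embedding_dim):
    """Hypothetical helper: sample w ~ U[w_min, w_max] and embed it."""
    w = [(w_max - w_min) * random.random() + w_min for _ in range(bsz)]
    return w, guidance_scale_embedding(w, embedding_dim)


w, w_embedding = sample_guidance(bsz=4, w_min=5.0, w_max=15.0, embedding_dim=256)
print(len(w_embedding), len(w_embedding[0]))
```

If the XL script followed the same pattern as the SD script, an embedding of this shape (batch size by unet_time_cond_proj_dim) is presumably what would be passed as timestep_cond instead of None.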
