`x[i] = (vae.encode(img[None].bfloat16()).latent_dist.mode()[0] - vae_shift) * vae_scale`
I noticed that the training code encodes images with `mode()` instead of `sample()` on the VAE's latent distribution. Could you please explain the reason behind this choice?
From my understanding, most training setups call `sample()` when encoding with a VAE. I'm curious whether using `mode()` provides any particular advantage in this case.
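To make the question concrete, here is a minimal sketch of the difference as I understand it. It assumes the encoder's posterior is a diagonal Gaussian; the `DiagonalGaussian` class and `normalize` helper are hypothetical stand-ins, not the repository's or any library's code. `mode()` returns the mean deterministically, while `sample()` adds noise scaled by the predicted standard deviation.

```python
import math
import random


class DiagonalGaussian:
    """Hypothetical stand-in for a VAE posterior q(z|x) = N(mean, diag(exp(logvar))).

    Assumption: this only mirrors the general shape of a diagonal Gaussian
    latent distribution, not any specific library's implementation.
    """

    def __init__(self, mean, logvar):
        self.mean = list(mean)
        # Clamp log-variance for numerical stability (an assumed, common choice).
        self.logvar = [min(max(lv, -30.0), 20.0) for lv in logvar]
        self.std = [math.exp(0.5 * lv) for lv in self.logvar]

    def mode(self):
        # For a Gaussian the mode equals the mean, so the output is deterministic.
        return list(self.mean)

    def sample(self, rng):
        # Reparameterized draw: mean + std * eps, with eps ~ N(0, 1).
        return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(self.mean, self.std)]


def normalize(latent, shift, scale):
    # Same arithmetic as the quoted line: (z - vae_shift) * vae_scale.
    return [(z - shift) * scale for z in latent]


posterior = DiagonalGaussian(mean=[0.5, -1.0, 2.0], logvar=[-4.0, -4.0, -4.0])
rng = random.Random(0)

# mode() yields the same latent every time the same image is encoded.
print(posterior.mode() == posterior.mode())  # True

# sample() adds noise with std exp(-2) (about 0.135 here), so two draws differ.
a, b = posterior.sample(rng), posterior.sample(rng)
print(a != b)

# Hypothetical shift/scale values, for illustration only.
print(normalize(posterior.mode(), shift=0.5, scale=2.0))  # [0.0, -3.0, 3.0]
```

In this sketch the two calls differ only by the `std * eps` term, so the practical gap depends on how large the encoder's predicted variance is.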
Thank you!