From b4e9dc4322dcc8a2e15dfb706418785dbf22fb0f Mon Sep 17 00:00:00 2001 From: Lionel Peer Date: Sun, 8 Oct 2023 15:09:03 +0200 Subject: [PATCH] added kl divergence --- README.md | 3 + diffusion_models/losses/elbo.py | 2 - diffusion_models/losses/kl_divergence.py | 75 ++++++++++++++++++++++++ diffusion_models/models/diffusion.py | 1 + tests/train_generative.py | 22 +++---- 5 files changed, 90 insertions(+), 13 deletions(-) delete mode 100644 diffusion_models/losses/elbo.py create mode 100644 diffusion_models/losses/kl_divergence.py diff --git a/README.md b/README.md index 3e24cfd..db959a3 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,9 @@ sbatch --job-name=NAME --output=log/%j.out --gres=gpu:1 --mem=10G subscript.sh S ```bash srun --time 10 --partition=gpu.debug --constraint='titan_xp|geforce_gtx_titan_x' --gres=gpu:1 --pty bash -i ``` +```bash +sinfo -o "%f" +``` ## VSCode Remote Troubleshooting ### Repeated Password Query diff --git a/diffusion_models/losses/elbo.py b/diffusion_models/losses/elbo.py deleted file mode 100644 index 8be0774..0000000 --- a/diffusion_models/losses/elbo.py +++ /dev/null @@ -1,2 +0,0 @@ -import torch -from torch import nn \ No newline at end of file diff --git a/diffusion_models/losses/kl_divergence.py b/diffusion_models/losses/kl_divergence.py new file mode 100644 index 0000000..42db01b --- /dev/null +++ b/diffusion_models/losses/kl_divergence.py @@ -0,0 +1,75 @@ +import torch +from torch import nn, Tensor +from jaxtyping import Float + +def gaussian_kl( + p_mean: Float[Tensor, "1"], + p_var: Float[Tensor, "1"], + q_mean: Float[Tensor, "1"], + q_var: Float[Tensor, "1"] + ) -> Float[Tensor, "1"]: + """Calculate KL Divergence of 2 Gaussian distributions. + + KL divergence between two univariate Gaussians, as derived in [1], with k=1 (dimensionality). + + Parameters + ---------- + p_mean + mean value of first distribution + p_var + variance value of first distribution + q_mean + mean value of second distribution + q_var + variance value of second distribution + + Returns + ------- + out + KL divergence of inputs + + References + ---------- + .. math:: + D_{KL}(p||q) = \frac{1}{2}\left[\log\frac{|\Sigma_q|}{|\Sigma_p|} - k + (\boldsymbol{\mu_p}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\boldsymbol{\mu_p}-\boldsymbol{\mu_q}) + tr\left\{\Sigma_q^{-1}\Sigma_p\right\}\right] + + .. [1] https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions/ + + """ + return 0.5 * (torch.log(torch.abs(q_var) / torch.abs(p_var)) - 1.0 + ((p_mean-q_mean)**2)/q_var + p_var/q_var) + +def log_gaussian_kl( + p_mean: Float[Tensor, "1"], + p_logvar: Float[Tensor, "1"], + q_mean: Float[Tensor, "1"], + q_logvar: Float[Tensor, "1"] + ) -> Float[Tensor, "1"]: + """Calculate KL Divergence of 2 Gaussian distributions. + + KL divergence between two univariate Gaussians, as derived in [1], with k=1 (dimensionality) and log variances. + + Parameters + ---------- + p_mean + mean value of first distribution + p_logvar + log of variance value of first distribution + q_mean + mean value of second distribution + q_logvar + log of variance value of second distribution + + Returns + ------- + out + KL divergence of inputs + + References + ---------- + .. math:: + D_{KL}(p||q) = \frac{1}{2}\left[\log\frac{|\Sigma_q|}{|\Sigma_p|} - k + (\boldsymbol{\mu_p}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\boldsymbol{\mu_p}-\boldsymbol{\mu_q}) + tr\left\{\Sigma_q^{-1}\Sigma_p\right\}\right] + + .. [1] https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions/ + + """ + return 0.5 * (q_logvar - p_logvar - 1.0 + torch.exp(p_logvar - q_logvar) + ((p_mean - q_mean)**2)*torch.exp(-q_logvar)) \ No newline at end of file diff --git a/diffusion_models/models/diffusion.py b/diffusion_models/models/diffusion.py index 51a7d2a..d625644 100644 --- a/diffusion_models/models/diffusion.py +++ b/diffusion_models/models/diffusion.py @@ -196,6 +196,7 @@ def sample( noise = torch.randn_like(x, device=device) else: noise = torch.zeros_like(x, device=device) + # mean is predicted by NN and refactored by alphas, beta is kept constant according to scheduler x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * noise_pred) + torch.sqrt(beta) * noise if debugging and (i % save_every == 0): x_list.append(x) diff --git a/tests/train_generative.py b/tests/train_generative.py index 3119fc4..d4111b0 100644 --- a/tests/train_generative.py +++ b/tests/train_generative.py @@ -18,11 +18,17 @@ import torch.nn.functional as F config = dotdict( - total_epochs = 3000, - log_wandb = True, + total_epochs = 2, + log_wandb = False, project = "cifar_gen_trials", checkpoint_folder = "/itet-stor/peerli/net_scratch/cifarGenLong_checkpoints", - save_every = 10, + #data_path = os.path.abspath("./data"), + #checkpoint_folder = os.path.abspath(os.path.join("./data/checkpoints")), + data_path = "/itet-stor/peerli/net_scratch", + checkpoint_folder = "/itet-stor/peerli/net_scratch/cifar10cosine_checkpoints", + loss_func = F.mse_loss, + project = "cifar_gen_trials", + save_every = 1, num_samples = 9, show_denoising_history = False, show_history_every = 50, @@ -39,7 +45,7 @@ activation = nn.SiLU, backbone_enc_depth = 4, kernel_size = 3, - dropout = 0, + dropout = 0.1, forward_diff = ForwardDiffusion, max_timesteps = 1000, t_start = 0.0001, @@ -48,13 +54,7 @@ max_beta = 0.999, schedule_type = "linear", time_enc_dim = 256, - optimizer = torch.optim.Adam, - #data_path = os.path.abspath("./data"), - #checkpoint_folder = os.path.abspath(os.path.join("./data/checkpoints")), - data_path = "/itet-stor/peerli/net_scratch", - checkpoint_folder = "/itet-stor/peerli/net_scratch/mnistGen2_checkpoints", - loss_func = F.mse_loss, - project = "mnist_gen_trials" + optimizer = torch.optim.Adam ) def load_train_objs(config):