added kl divergence
liopeer committed Oct 8, 2023
1 parent ab1dca8 commit b4e9dc4
Showing 5 changed files with 90 additions and 13 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -20,6 +20,9 @@ sbatch --job-name=NAME --output=log/%j.out --gres=gpu:1 --mem=10G subscript.sh S
```bash
srun --time 10 --partition=gpu.debug --constraint='titan_xp|geforce_gtx_titan_x' --gres=gpu:1 --pty bash -i
```
To list the node features that can be requested via `--constraint` (such as `titan_xp` above):
```bash
sinfo -o "%f"
```

## VSCode Remote Troubleshooting
### Repeated Password Query
2 changes: 0 additions & 2 deletions diffusion_models/losses/elbo.py

This file was deleted.

75 changes: 75 additions & 0 deletions diffusion_models/losses/kl_divergence.py
@@ -0,0 +1,75 @@
import torch
from torch import nn, Tensor
from jaxtyping import Float

def gaussian_kl(
    p_mean: Float[Tensor, "1"],
    p_var: Float[Tensor, "1"],
    q_mean: Float[Tensor, "1"],
    q_var: Float[Tensor, "1"]
) -> Float[Tensor, "1"]:
    r"""Calculate the KL divergence of two Gaussian distributions.

    KL divergence between two univariate Gaussians, as derived in [1]_, with k=1 (dimensionality).

    Parameters
    ----------
    p_mean
        mean of the first distribution
    p_var
        variance of the first distribution
    q_mean
        mean of the second distribution
    q_var
        variance of the second distribution

    Returns
    -------
    out
        KL divergence of the inputs

    Notes
    -----
    .. math::
        D_{KL}(p||q) = \frac{1}{2}\left[\log\frac{|\Sigma_q|}{|\Sigma_p|} - k + (\boldsymbol{\mu_p}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\boldsymbol{\mu_p}-\boldsymbol{\mu_q}) + \mathrm{tr}\left\{\Sigma_q^{-1}\Sigma_p\right\}\right]

    References
    ----------
    .. [1] https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions/
    """
    # univariate case (k=1): 0.5 * (log(sigma_q^2 / sigma_p^2) - 1 + (mu_p - mu_q)^2 / sigma_q^2 + sigma_p^2 / sigma_q^2)
    return 0.5 * (torch.log(torch.abs(q_var) / torch.abs(p_var)) - 1.0 + ((p_mean - q_mean)**2) / q_var + p_var / q_var)

def log_gaussian_kl(
    p_mean: Float[Tensor, "1"],
    p_logvar: Float[Tensor, "1"],
    q_mean: Float[Tensor, "1"],
    q_logvar: Float[Tensor, "1"]
) -> Float[Tensor, "1"]:
    r"""Calculate the KL divergence of two Gaussian distributions from log variances.

    KL divergence between two univariate Gaussians, as derived in [1]_, with k=1 (dimensionality) and variances given in log space.

    Parameters
    ----------
    p_mean
        mean of the first distribution
    p_logvar
        log of the variance of the first distribution
    q_mean
        mean of the second distribution
    q_logvar
        log of the variance of the second distribution

    Returns
    -------
    out
        KL divergence of the inputs

    Notes
    -----
    .. math::
        D_{KL}(p||q) = \frac{1}{2}\left[\log\frac{|\Sigma_q|}{|\Sigma_p|} - k + (\boldsymbol{\mu_p}-\boldsymbol{\mu_q})^T\Sigma_q^{-1}(\boldsymbol{\mu_p}-\boldsymbol{\mu_q}) + \mathrm{tr}\left\{\Sigma_q^{-1}\Sigma_p\right\}\right]

    References
    ----------
    .. [1] https://mr-easy.github.io/2020-04-16-kl-divergence-between-2-gaussian-distributions/
    """
    # log(sigma_q^2 / sigma_p^2) = q_logvar - p_logvar; sigma_p^2 / sigma_q^2 = exp(p_logvar - q_logvar)
    return 0.5 * (q_logvar - p_logvar - 1.0 + torch.exp(p_logvar - q_logvar) + ((p_mean - q_mean)**2) * torch.exp(-q_logvar))
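
As a quick sanity check (a sketch, not part of the commit; it assumes the package is importable from the repository root), the two parameterizations should agree when the log-variance form is given the logs of the same variances:

```python
import torch

from diffusion_models.losses.kl_divergence import gaussian_kl, log_gaussian_kl

# two univariate Gaussians: p = N(0, 1), q = N(1, 2)
p_mean, p_var = torch.tensor([0.0]), torch.tensor([1.0])
q_mean, q_var = torch.tensor([1.0]), torch.tensor([2.0])

kl_plain = gaussian_kl(p_mean, p_var, q_mean, q_var)
kl_log = log_gaussian_kl(p_mean, torch.log(p_var), q_mean, torch.log(q_var))

assert torch.allclose(kl_plain, kl_log)  # both are 0.5 * ln(2), roughly 0.3466
# KL divergence of a distribution with itself is zero
assert torch.allclose(gaussian_kl(p_mean, p_var, p_mean, p_var), torch.zeros(1))
```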
1 change: 1 addition & 0 deletions diffusion_models/models/diffusion.py
@@ -196,6 +196,7 @@ def sample(
    noise = torch.randn_like(x, device=device)
else:
    noise = torch.zeros_like(x, device=device)
# mean is predicted by the NN and rescaled by the alphas; beta is kept constant according to the scheduler
x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * noise_pred) + torch.sqrt(beta) * noise
if debugging and (i % save_every == 0):
    x_list.append(x)
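
For context (not part of the diff): assuming `noise_pred` is the network's noise prediction and `alpha_hat` the cumulative product of the alphas, the new line matches the standard DDPM ancestral-sampling update (Ho et al., 2020):

```latex
x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}} \, \epsilon_\theta(x_t, t) \right) + \sqrt{\beta_t} \, z, \qquad z \sim \mathcal{N}(0, I)
```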
22 changes: 11 additions & 11 deletions tests/train_generative.py
@@ -18,11 +18,17 @@
import torch.nn.functional as F

config = dotdict(
    total_epochs = 3000,
    log_wandb = True,
    total_epochs = 2,
    log_wandb = False,
    project = "cifar_gen_trials",
    checkpoint_folder = "/itet-stor/peerli/net_scratch/cifarGenLong_checkpoints",
    save_every = 10,
    #data_path = os.path.abspath("./data"),
    #checkpoint_folder = os.path.abspath(os.path.join("./data/checkpoints")),
    data_path = "/itet-stor/peerli/net_scratch",
    checkpoint_folder = "/itet-stor/peerli/net_scratch/cifar10cosine_checkpoints",
    loss_func = F.mse_loss,
    project = "cifar_gen_trials",
    save_every = 1,
    num_samples = 9,
    show_denoising_history = False,
    show_history_every = 50,
@@ -39,7 +45,7 @@
    activation = nn.SiLU,
    backbone_enc_depth = 4,
    kernel_size = 3,
    dropout = 0,
    dropout = 0.1,
    forward_diff = ForwardDiffusion,
    max_timesteps = 1000,
    t_start = 0.0001,
@@ -48,13 +54,7 @@
    max_beta = 0.999,
    schedule_type = "linear",
    time_enc_dim = 256,
    optimizer = torch.optim.Adam,
    #data_path = os.path.abspath("./data"),
    #checkpoint_folder = os.path.abspath(os.path.join("./data/checkpoints")),
    data_path = "/itet-stor/peerli/net_scratch",
    checkpoint_folder = "/itet-stor/peerli/net_scratch/mnistGen2_checkpoints",
    loss_func = F.mse_loss,
    project = "mnist_gen_trials"
    optimizer = torch.optim.Adam
)

def load_train_objs(config):
