Skip to content

Commit

Permalink
debugged flexible noising
Browse files Browse the repository at this point in the history
  • Loading branch information
liopeer committed Nov 9, 2023
1 parent c1f2e7d commit e4b29da
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 6 deletions.
4 changes: 2 additions & 2 deletions diffusion_models/models/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def forward_flexible(
alphas_dash_interval = torch.cumprod(alphas_interval, axis=0)
sqrt_alphas_dash_interval = torch.sqrt(alphas_dash_interval)
sqrt_one_minus_alphas_dash_interval = torch.sqrt(1. - alphas_dash_interval)
batch_sqrt_alphas_dash[sample] = sqrt_alphas_dash_interval
batch_sqrt_one_minus_alpha_dash[sample] = sqrt_one_minus_alphas_dash_interval
batch_sqrt_alphas_dash[sample] = sqrt_alphas_dash_interval[-1]
batch_sqrt_one_minus_alpha_dash[sample] = sqrt_one_minus_alphas_dash_interval[-1]
mean = batch_sqrt_alphas_dash.view(-1, 1, 1, 1) * x_t1
out = mean + batch_sqrt_one_minus_alpha_dash.view(-1, 1, 1, 1) * noise_normal
return out, noise_normal
Expand Down
Empty file.
96 changes: 93 additions & 3 deletions tests/sampler_tests/flexible_noising.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/sampler_tests/sampler_dutifulpond10.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def run():
config = None
with open("config_dutifulpond10.json", "r") as f:
config = json.load(f)
ckp_num = 40
ckp_num = 32
ckpt_path = f"/itet-stor/peerli/net_scratch/dutiful-pond-10/checkpoint{ckp_num}.pt"
device = torch.device("cuda")
backbone = UNet(
Expand Down

0 comments on commit e4b29da

Please sign in to comment.