Skip to content

Commit e4b29da

Browse files
committed
debugged flexible noising
1 parent c1f2e7d commit e4b29da

File tree

4 files changed

+96
-6
lines changed

4 files changed

+96
-6
lines changed

diffusion_models/models/diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def forward_flexible(
121121
alphas_dash_interval = torch.cumprod(alphas_interval, axis=0)
122122
sqrt_alphas_dash_interval = torch.sqrt(alphas_dash_interval)
123123
sqrt_one_minus_alphas_dash_interval = torch.sqrt(1. - alphas_dash_interval)
124-
batch_sqrt_alphas_dash[sample] = sqrt_alphas_dash_interval
125-
batch_sqrt_one_minus_alpha_dash[sample] = sqrt_one_minus_alphas_dash_interval
124+
batch_sqrt_alphas_dash[sample] = sqrt_alphas_dash_interval[-1]
125+
batch_sqrt_one_minus_alpha_dash[sample] = sqrt_one_minus_alphas_dash_interval[-1]
126126
mean = batch_sqrt_alphas_dash.view(-1, 1, 1, 1) * x_t1
127127
out = mean + batch_sqrt_one_minus_alpha_dash.view(-1, 1, 1, 1) * noise_normal
128128
return out, noise_normal

tests/sampler_tests/flexible_denoising.ipynb

Whitespace-only changes.

tests/sampler_tests/flexible_noising.ipynb

Lines changed: 93 additions & 3 deletions
Large diffs are not rendered by default.

tests/sampler_tests/sampler_dutifulpond10.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def run():
1313
config = None
1414
with open("config_dutifulpond10.json", "r") as f:
1515
config = json.load(f)
16-
ckp_num = 40
16+
ckp_num = 32
1717
ckpt_path = f"/itet-stor/peerli/net_scratch/dutiful-pond-10/checkpoint{ckp_num}.pt"
1818
device = torch.device("cuda")
1919
backbone = UNet(

0 commit comments

Comments
 (0)