
Commit ec7d469

added flexible forward diffusion

1 parent 95590ea commit ec7d469

7 files changed: +289 -49 lines

diffusion_models/models/diffusion.py

+92 -4
@@ -47,7 +47,7 @@ def __init__(self, timesteps: int, start: float=0.0001, end: float=0.02, offset:
         if self.type == "linear":
             self.init_betas = self._linear_scheduler(timesteps=self.timesteps, start=self.start, end=self.end)
         elif self.type == "cosine":
-            self.init_betas = self._cosine_scheduler(timesteps=self.timesteps, start=self.start, end=self.end)
+            self.init_betas = self._cosine_scheduler(timesteps=self.timesteps, offset=self.offset, max_beta=self.max_beta)
         else:
             raise NotImplementedError("Invalid scheduler option:", type)
         self.init_alphas = 1. - self.init_betas
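The constructor now forwards `offset` and `max_beta` to the cosine scheduler, which builds betas from an alpha-bar function (see the next hunk; its loop body is truncated there). For orientation, a self-contained sketch of the standard construction this presumably follows; the helper name and the exact clipping are assumptions, not code from this repo:

import math
import torch

def cosine_betas(timesteps: int, offset: float = 0.008, max_beta: float = 0.999) -> torch.Tensor:
    # alpha_bar(t/T) = cos(((t/T) + offset) / (1 + offset) * pi/2) ** 2, as in the lambda above
    alpha_bar = lambda t: math.cos((t + offset) / (1.0 + offset) * math.pi / 2) ** 2
    betas = []
    for i in range(timesteps):
        t1, t2 = i / timesteps, (i + 1) / timesteps
        # beta_i is the relative drop of alpha_bar over one step, clipped by max_beta near t = T
        betas.append(min(1.0 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas)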
@@ -90,15 +90,51 @@ def forward(
         x_t = sqrt_alphas_dash_t.view(-1, 1, 1, 1) * x_0
         x_t += sqrt_one_minus_alphas_dash_t.view(-1, 1, 1, 1) * noise_normal
         return x_t, noise_normal
+
+    def forward_flexible(
+        self,
+        x_t1: Float[Tensor, "batch channels height width"],
+        t_1: Int64[Tensor, "batch"],
+        t_2: Int64[Tensor, "batch"]
+    ) -> Float[Tensor, "batch channels height width"]:
+        """Flexible method that enables jumping from/to any timestep in the forward diffusion process.
+
+        Parameters
+        ----------
+        x_t1
+            batch of (partially noisy) inputs at different stages
+        t_1
+            initial timesteps of the forward process (the timesteps the x_t1 above are currently at)
+        t_2
+            timesteps to transport x_t1 to (elements must be larger than the corresponding elements of t_1)
+        """
+        diff = t_2 - t_1
+        if diff[diff < 0].shape[0] != 0:
+            raise ValueError("Timesteps in forward process must increase.")
+        noise_normal = torch.randn_like(x_t1, device=x_t1.device)
+        if (True in torch.gt(t_1, self.timesteps - 1)) or (True in torch.gt(t_2, self.timesteps - 1)):
+            raise IndexError("t ({}, {}) chosen larger than max. available t ({})".format(t_1, t_2, self.timesteps - 1))
+        batch_sqrt_alphas_dash = torch.zeros((t_1.shape[0]), device=x_t1.device)
+        batch_sqrt_one_minus_alpha_dash = torch.zeros((t_1.shape[0]), device=x_t1.device)
+        for sample in range(x_t1.shape[0]):
+            alphas_interval = self.alphas[t_1[sample]:t_2[sample]+1]
+            alphas_dash_interval = torch.cumprod(alphas_interval, axis=0)
+            sqrt_alphas_dash_interval = torch.sqrt(alphas_dash_interval)
+            sqrt_one_minus_alphas_dash_interval = torch.sqrt(1. - alphas_dash_interval)
+            # keep the cumulative product over the whole interval (last element), one scalar per sample
+            batch_sqrt_alphas_dash[sample] = sqrt_alphas_dash_interval[-1]
+            batch_sqrt_one_minus_alpha_dash[sample] = sqrt_one_minus_alphas_dash_interval[-1]
+        mean = batch_sqrt_alphas_dash.view(-1, 1, 1, 1) * x_t1
+        out = mean + batch_sqrt_one_minus_alpha_dash.view(-1, 1, 1, 1) * noise_normal
+        return out, noise_normal

     def _linear_scheduler(self, timesteps, start, end):
         return torch.linspace(start, end, timesteps)

-    def _cosine_scheduler(self, timesteps, start, end):
+    def _cosine_scheduler(self, timesteps, offset, max_beta):
         """t is actually t/T from the paper"""
-        return self._betas_for_alpha_bar(timesteps, lambda t: math.cos((t + self.offset) / (1.0 + self.offset) * math.pi / 2) ** 2, self.max_beta)
+        return self._betas_for_alpha_bar(timesteps, lambda t: math.cos((t + offset) / (1.0 + offset) * math.pi / 2) ** 2, max_beta)

-    def _betas_for_alpha_bar(self, num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+    def _betas_for_alpha_bar(self, num_diffusion_timesteps, alpha_bar, max_beta):
         betas = []
         for i in range(num_diffusion_timesteps):
             t1 = i / num_diffusion_timesteps  # t -> t/T
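A quick, illustrative call pattern for the new method. The shapes, timestep values, and the class handle are made up for this example; only the method names come from the diff above:

# hypothetical usage sketch; "ForwardDiffusion" is an assumed name for the class defined in this file
fwd_diff = ForwardDiffusion(timesteps=1000, type="linear")
x_0 = torch.randn(4, 1, 32, 32)               # stand-in for a batch of clean images
t_1 = torch.tensor([10, 50, 100, 200])        # timesteps the batch is currently at
t_2 = torch.tensor([20, 60, 500, 999])        # later timesteps to transport to (elementwise >= t_1)
x_t1, _ = fwd_diff(x_0, t_1)                  # ordinary forward pass: clean -> t_1
x_t2, noise = fwd_diff.forward_flexible(x_t1, t_1, t_2)   # jump directly from t_1 to t_2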
@@ -158,11 +194,63 @@ def forward(
         """
         timesteps = self._sample_timesteps(x.shape[0], device=x.device)
         time_enc = self.time_encoder.get_pos_encoding(timesteps)
+        # make (partially) noisy versions of the batch, returns noisy version + applied noise
         x_t, noise = self.fwd_diff(x, timesteps)
+        # predict the applied noise from the noisy version
         noise_pred = self.model(x_t, time_enc)
         return noise_pred, noise

+    def init_noise(self, num_samples: int):
+        return torch.randn((num_samples, self.model.in_channels, self.img_size, self.img_size), device=list(self.parameters())[0].device)
+
+    def denoise_singlestep(
+        self,
+        x: Float[Tensor, "batch channels height width"],
+        t: Int64[Tensor, "batch"]
+    ) -> Float[Tensor, "batch channels height width"]:
+        """Denoise a single timestep in the reverse direction.
+
+        Parameters
+        ----------
+        x
+            tensor representing a batch of noisy pictures (may be at different timesteps)
+        t
+            tensor representing the timesteps t for the batch
+
+        Returns
+        -------
+        out
+            less noisy version (by one timestep)
+        """
+        self.model.eval()
+        with torch.no_grad():
+            t_enc = self.time_encoder.get_pos_encoding(t)
+            noise_pred = self.model(x, t_enc)
+            alpha = self.fwd_diff.alphas[t][:, None, None, None]
+            alpha_hat = self.fwd_diff.alphas_dash[t][:, None, None, None]
+            beta = self.fwd_diff.betas[t][:, None, None, None]
+            noise = torch.randn_like(x, device=noise_pred.device)
+            # noise where t = 1 should be zero
+            (t_one_idx, ) = torch.where(t == 1)
+            noise[t_one_idx] = 0
+            x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * noise_pred) + torch.sqrt(beta) * noise
+        self.model.train()
+        return x
+
     def sample(
+        self,
+        num_samples: int,
+        debugging: bool=False,
+        save_every: int=20
+    ) -> Float[Tensor, "batch channel height width"]:
+        beta = self.fwd_diff.betas[-1].view(-1, 1, 1, 1)
+        x = self.init_noise(num_samples) * torch.sqrt(beta)
+        for i in reversed(range(1, self.fwd_diff.timesteps)):
+            t = i * torch.ones((num_samples), dtype=torch.long, device=list(self.model.parameters())[0].device)
+            x = self.denoise_singlestep(x, t)
+        return x
+
+    def sample2(
         self,
         num_samples: int,
         debugging: bool=False,
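In `denoise_singlestep`, the update of `x` matches the standard DDPM reverse step; with `alphas_dash` denoting the cumulative product of the alphas, it corresponds to

x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sqrt{\beta_t}\, z, \qquad z \sim \mathcal{N}(0, I),

with z set to zero for the t == 1 indices, i.e. the last step of the loop in `sample`, exactly as the code does.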

diffusion_models/models/unet.py

+128 -5
@@ -231,7 +231,10 @@ def __init__(
         dropout: float=0.5,
         activation: nn.Module=nn.SiLU,
         verbose: bool=False,
-        init_channels: int=64
+        init_channels: int=64,
+        attention: bool=True,
+        attention_heads: int=4,
+        attention_ff_dim: int=None
     ) -> None:
         """Constructor of UNet.

@@ -251,6 +254,14 @@ def __init__(
             activation function to be used
         verbose
             verbose printing of tensor shapes for debugging
+        init_channels
+            number of channels to initially transform the input to (usually 64, 128, ...)
+        attention
+            whether to use self-attention layers
+        attention_heads
+            number of attention heads to be used
+        attention_ff_dim
+            hidden dimension of feedforward layer in self attention module; None defaults to input dimension
         """
         super().__init__()
         self.num_layers = num_encoding_blocks
@@ -263,6 +274,9 @@ def __init__(
         self.activation = activation
         self.verbose = verbose
         self.init_channels = init_channels
+        self.attention = attention
+        self.attention_heads = attention_heads
+        self.attention_ff_dim = attention_ff_dim

         self.encoding_channels, self.decoding_channels = self._get_channel_lists(init_channels, num_encoding_blocks)

@@ -273,7 +287,10 @@ def __init__(
             nn.Dropout(self.dropout)
         )

-        self.encoder = nn.ModuleList([EncodingBlock(self.encoding_channels[i], self.encoding_channels[i+1], time_emb_size, kernel_size, dropout, self.activation, verbose) for i in range(len(self.encoding_channels[:-1]))])
+        if attention:
+            self.encoder = nn.ModuleList([AttentionEncodingBlock(self.encoding_channels[i], self.encoding_channels[i+1], time_emb_size, kernel_size, dropout, self.activation, verbose, attention_heads, attention_ff_dim) for i in range(len(self.encoding_channels[:-1]))])
+        else:
+            self.encoder = nn.ModuleList([EncodingBlock(self.encoding_channels[i], self.encoding_channels[i+1], time_emb_size, kernel_size, dropout, self.activation, verbose) for i in range(len(self.encoding_channels[:-1]))])

         self.bottleneck = nn.Sequential(
             nn.Conv2d(self.encoding_channels[-1], self.encoding_channels[-1] * 2, kernel_size=self.kernel_size, padding="same"),
@@ -286,8 +303,11 @@ def __init__(
             nn.Dropout(self.dropout)
         )

-        self.decoder = nn.ModuleList([DecodingBlock(self.decoding_channels[i], self.decoding_channels[i+1], time_emb_size, kernel_size, dropout, self.activation, verbose) for i in range(len(self.encoding_channels[:-1]))])
-
+        if attention:
+            self.decoder = nn.ModuleList([AttentionDecodingBlock(self.decoding_channels[i], self.decoding_channels[i+1], time_emb_size, kernel_size, dropout, self.activation, verbose, attention_heads, attention_ff_dim) for i in range(len(self.encoding_channels[:-1]))])
+        else:
+            self.decoder = nn.ModuleList([DecodingBlock(self.decoding_channels[i], self.decoding_channels[i+1], time_emb_size, kernel_size, dropout, self.activation, verbose) for i in range(len(self.encoding_channels[:-1]))])
+
         self.out_conv = nn.Conv2d(init_channels, in_channels, kernel_size=kernel_size, padding="same")

     def _get_channel_lists(self, start_channels, num_layers):
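With the new flags, attention is switched on or off per constructor call. A minimal instantiation sketch; the import path and all argument values are placeholders inferred from this diff and the test script below, not taken from the repo:

import torch.nn as nn
from diffusion_models.models.unet import UNet   # path assumed from the file header above

unet = UNet(
    num_encoding_blocks=4,
    in_channels=1,
    kernel_size=3,
    dropout=0.1,
    activation=nn.SiLU,
    time_emb_size=256,
    init_channels=64,
    attention=True,         # use AttentionEncodingBlock / AttentionDecodingBlock
    attention_heads=4,
    attention_ff_dim=None,  # None: SelfAttention falls back to the block's channel count
)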
@@ -367,4 +387,107 @@ def _check_sizes(self, x):
         heights = [(elem.is_integer() and (elem % 2 == 0)) for elem in heights]
         if (False in widths) or (False in heights):
             return False
-        return True
+        return True
+
+class SelfAttention(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        num_heads: int,
+        dropout: float,
+        dim_feedforward: int=None,
+        activation: nn.Module=nn.SiLU
+    ) -> None:
+        """Constructor of SelfAttention module.
+
+        Implementation of a self-attention layer for image data.
+
+        Parameters
+        ----------
+        channels
+            number of input channels
+        num_heads
+            number of desired attention heads
+        dropout
+            dropout probability value
+        dim_feedforward
+            dimension of hidden layers in the feedforward NN, defaults to number of input channels
+        activation
+            activation function to be used, as uninstantiated nn.Module
+        """
+        super().__init__()
+        self.channels = channels
+        self.num_heads = num_heads
+        self.dropout = dropout
+        if dim_feedforward is not None:
+            self.dim_feedforward = dim_feedforward
+        else:
+            self.dim_feedforward = channels
+        self.activation = activation()
+        self.attention_layer = nn.TransformerEncoderLayer(
+            channels,
+            num_heads,
+            self.dim_feedforward,
+            dropout,
+            self.activation,
+            batch_first=True
+        )
+
+    def forward(self, x: Float[Tensor, "batch channels height width"]) -> Float[Tensor, "batch channels height width"]:
+        """Forward method of SelfAttention module.
+
+        Parameters
+        ----------
+        x
+            input tensor
+
+        Returns
+        -------
+        out
+            output tensor
+        """
+        # transform feature maps into vectors and put feature dimension (channels) at the end
+        orig_size = x.size()
+        x = x.view(-1, x.shape[1], x.shape[2]*x.shape[3]).swapaxes(1, 2)
+        x = self.attention_layer(x)
+        return x.swapaxes(1, 2).view(*orig_size)
+
+class AttentionEncodingBlock(EncodingBlock):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        time_embedding_size: int,
+        kernel_size: int = 3,
+        dropout: float = 0.5,
+        activation: nn.Module = nn.SiLU,
+        verbose: bool = False,
+        attention_heads: int=4,
+        attention_ff_dim: int=None
+    ) -> None:
+        super().__init__(in_channels, out_channels, time_embedding_size, kernel_size, dropout, activation, verbose)
+        self.sa = SelfAttention(out_channels, attention_heads, dropout, attention_ff_dim, activation)
+
+    def forward(self, x: Tensor, time_embedding: Tensor) -> Tuple[Tensor, Tensor]:
+        out, skip = super().forward(x, time_embedding)
+        return self.sa(out), skip
+
+class AttentionDecodingBlock(DecodingBlock):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        time_embedding_size: int,
+        kernel_size: int = 3,
+        dropout: float = 0.5,
+        activation: nn.Module = nn.SiLU,
+        verbose: bool = False,
+        attention_heads: int=4,
+        attention_ff_dim: int=None
+    ) -> None:
+        super().__init__(in_channels, out_channels, time_embedding_size, kernel_size, dropout, activation, verbose)
+        self.sa = SelfAttention(out_channels, attention_heads, dropout, attention_ff_dim, activation)
+
+    def forward(self, x: Tensor, skip: Tensor, time_embedding: Tensor = None) -> Tensor:
+        out = super().forward(x, skip, time_embedding)
+        return self.sa(out)
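SelfAttention treats every spatial location as a token: (batch, channels, H, W) is flattened to (batch, H*W, channels), run through an nn.TransformerEncoderLayer, and reshaped back, so the block preserves shape and can follow any conv block whose channel count is divisible by the head count. A small smoke test, assuming the class above is importable from this module:

import torch
from diffusion_models.models.unet import SelfAttention   # import path assumed

sa = SelfAttention(channels=64, num_heads=4, dropout=0.1)
x = torch.randn(2, 64, 16, 16)
out = sa(x)
assert out.shape == x.shape   # attention mixes information across the 16*16 positions, shape unchanged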

diffusion_models/utils/datasets.py

+1 -1
@@ -15,7 +15,7 @@ def __init__(self, root: str, train: bool = True, transform: Callable[..., Any]
         super().__init__(root, train, transform, target_transform, download)

 class Cifar10DebugDataset(Cifar10Dataset):
-    __len__ = lambda x: 100
+    __len__ = lambda x: 5000

 class MNISTTrainDataset(MNIST):
     def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:

tests/flexible_foward.py

+38
@@ -0,0 +1,38 @@
+import torch
+from train_generative import config
+import torchvision
+import torch.nn as nn
+from math import sqrt
+
+model = config.architecture(
+    backbone = config.backbone(
+        num_encoding_blocks = 4,
+        in_channels = 1,
+        kernel_size = 3,
+        dropout = config.dropout,
+        activation = nn.SiLU,
+        time_emb_size = config.time_enc_dim,
+        init_channels = 128,
+        attention = False,
+        attention_heads = 0,
+        attention_ff_dim = 0
+    ),
+    fwd_diff = config.forward_diff(
+        timesteps = config.max_timesteps,
+        start = config.t_start,
+        end = config.t_end,
+        offset = config.offset,
+        max_beta = config.max_beta,
+        type = "linear"
+    ),
+    img_size = config.img_size,
+    time_enc_dim = config.time_enc_dim,
+    dropout = config.dropout
+)
+
+model = model.to("cuda")
+model.load_state_dict(torch.load("/itet-stor/peerli/net_scratch/ghoulish-goosebump-9/checkpoint90.pt"))
+
+samples = model.sample(9)
+samples = torchvision.utils.make_grid(samples, nrow=int(sqrt(9)))
+torchvision.utils.save_image(samples, "/home/peerli/Downloads/sample2.png")
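As committed, the script only exercises reverse sampling via model.sample(9). To also touch the new forward path one could append something along these lines; this is purely illustrative, not part of the commit, and assumes the forward-diffusion module is reachable as model.fwd_diff with its buffers on the same device:

x_0 = torch.randn(9, 1, config.img_size, config.img_size, device="cuda")   # stand-in batch of "clean" images
t_1 = torch.randint(0, config.max_timesteps // 2, (9,), device="cuda")
t_2 = torch.clamp(t_1 + 100, max=config.max_timesteps - 1)                 # later timesteps, capped at the last step
x_t1, _ = model.fwd_diff(x_0, t_1)                                         # noise to t_1
x_t2, _ = model.fwd_diff.forward_flexible(x_t1, t_1, t_2)                  # jump from t_1 to t_2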

tests/job.sh

+2 -2
@@ -2,12 +2,12 @@
 #SBATCH --account=student
 #SBATCH --output=log/%j.out
 #SBATCH --error=log/%j.err
-#SBATCH --gres=gpu:4
+#SBATCH --gres=gpu:2
 #SBATCH --mem=32G
 #SBATCH --job-name=mnist_double
 #SBATCH --constraint='titan_xp|geforce_gtx_titan_x'

 source /scratch_net/biwidl311/peerli/conda/etc/profile.d/conda.sh
 conda activate liotorch
 mkdir log
-python -u train_generative.py "$@"
+python -u train_generative.py "$@"
