Skip to content

Commit 2d9c488

Browse files
committed
ready for training
1 parent 0319f0b commit 2d9c488

File tree

9 files changed

+221
-72
lines changed

9 files changed

+221
-72
lines changed

diffusion_models/models/diffusion.py

Lines changed: 60 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,8 @@
11
import torch
22
from torch import nn, Tensor
3-
from jaxtyping import Float, Int64
3+
from jaxtyping import Float, Int64, Int
44
from typing import Literal
55

6-
class DiffusionModel(nn.Module):
7-
def __init__(
8-
self,
9-
backbone: nn.Module,
10-
timesteps: int,
11-
t_start: float=0.0001,
12-
t_end: float=0.02,
13-
schedule_type: Literal["linear", "cosine"]="linear"
14-
) -> None:
15-
super().__init__()
16-
self.model = backbone
17-
self.fwd_diff = ForwardDiffusion(timesteps, t_start, t_end, schedule_type)
18-
19-
def forward(self, x):
20-
t = self._sample_timestep(x.shape[0])
21-
t = t.unsqueeze(-1).type(torch.float)
22-
t = self._pos_encoding(t, self.time_dim)
23-
x_t, noise = self.fwd_diff(x, t)
24-
noise_pred = self.model(x_t, t)
25-
return noise_pred, noise
26-
27-
def _pos_encoding(self, t, channels):
28-
inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=self.device).float() / channels))
29-
pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
30-
pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
31-
pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
32-
return pos_enc
33-
34-
def _sample_timestep(self, batch_size: int) -> Int64[Tensor, "batch"]:
35-
return torch.randint(low=1, high=self.fwd_diff.noise_steps, size=(batch_size,))
36-
37-
386
class ForwardDiffusion(nn.Module):
397
"""Class for forward diffusion process in DDPMs (denoising diffusion probabilistic models).
408
@@ -82,7 +50,11 @@ def __init__(self, timesteps: int, start: float=0.0001, end: float=0.02, type: L
8250

8351
self.register_buffer("noise_normal", torch.empty((1)), persistent=False)
8452

85-
def forward(self, x_0: Float[Tensor, "batch channels height width"], t: int) -> Float[Tensor, "batch channels height width"]:
53+
def forward(
54+
self,
55+
x_0: Float[Tensor, "batch channels height width"],
56+
t: Int[Tensor, "batch"]
57+
) -> Float[Tensor, "batch channels height width"]:
8658
"""Forward method of ForwardDiffusion class.
8759
8860
Parameters
@@ -98,15 +70,65 @@ def forward(self, x_0: Float[Tensor, "batch channels height width"], t: int) ->
9870
tensor with applied noise according to schedule and chosen timestep
9971
"""
10072
self.noise_normal = torch.randn_like(x_0)
101-
if t > self.timesteps-1:
73+
if True in torch.gt(t, self.timesteps-1):
10274
raise IndexError("t ({}) chosen larger than max. available t ({})".format(t, self.timesteps-1))
10375
sqrt_alpha_dash_t = self.sqrt_alphas_dash[t]
10476
sqrt_one_minus_alpha_dash_t = self.sqrt_one_minus_alpha_dash[t]
105-
x_t = sqrt_alpha_dash_t * x_0 + sqrt_one_minus_alpha_dash_t * self.noise_normal
106-
return x_t
77+
x_t = sqrt_alpha_dash_t.view(-1, 1, 1, 1) * x_0
78+
x_t += sqrt_one_minus_alpha_dash_t.view(-1, 1, 1, 1) * self.noise_normal
79+
return x_t, self.noise_normal
10780

10881
def _linear_scheduler(self, timesteps, start, end):
10982
return torch.linspace(start, end, timesteps)
11083

11184
def _cosine_scheduler(self, timesteps, start, end):
112-
raise NotImplementedError("Cosine scheduler not implemented yet.")
85+
raise NotImplementedError("Cosine scheduler not implemented yet.")
86+
87+
class DiffusionModel(nn.Module):
    """DDPM wrapper pairing a noise-prediction backbone with a forward diffusion process.

    The forward pass samples one random timestep per batch element, noises the
    input accordingly via `fwd_diff`, and asks the backbone to predict the
    injected noise conditioned on a sinusoidal time encoding.
    """

    def __init__(
        self,
        backbone: nn.Module,
        fwd_diff: "ForwardDiffusion",
        time_enc_dim: int=256
    ) -> None:
        """Initialize DiffusionModel.

        Parameters
        ----------
        backbone
            noise-prediction network, called as ``backbone(x_t, time_encoding)``
        fwd_diff
            forward diffusion process used to create noisy training inputs
        time_enc_dim
            dimensionality of the sinusoidal time encoding fed to the backbone
        """
        super().__init__()
        self.model = backbone
        self.fwd_diff = fwd_diff
        self.time_enc_dim = time_enc_dim

        # Buffers so the most recent timesteps/encodings follow .to(device);
        # they are overwritten on every forward pass and excluded from the
        # state dict (persistent=False).
        self.register_buffer("timesteps", torch.empty((1)), persistent=False)
        self.register_buffer("time_enc", torch.empty((1)), persistent=False)

    def forward(self, x):
        """Noise a batch, predict the noise; return (noise_pred, noise)."""
        # sample batch of timesteps on the same device as the input
        # (fix: previously sampled on CPU, which breaks indexing of CUDA
        # buffers inside ForwardDiffusion when x lives on a GPU)
        self.timesteps = self._sample_timesteps(x.shape[0], device=x.device)

        # convert timesteps into sinusoidal time encodings
        self.time_enc = self._time_encoding(self.timesteps, self.time_enc_dim)

        # create batch of noisy images
        x_t, noise = self.fwd_diff(x, self.timesteps)

        # run noisy images, conditioned on time, through the model
        noise_pred = self.model(x_t, self.time_enc)
        return noise_pred, noise

    def sample(self, n):
        """Sample a batch of n images via the reverse process. Not implemented yet."""
        pass

    def _time_encoding(
        self,
        t: "Int[Tensor, 'batch']",
        channels: int
    ) -> "Float[Tensor, 'batch time_enc_dim']":
        """Sinusoidal positional encoding of timesteps `t` with `channels` dimensions.

        First channels//2 entries are sines, the rest cosines, with
        frequencies spaced geometrically (base 10000), as in Transformer
        positional encodings.
        """
        t = t.unsqueeze(-1).type(torch.float)
        # build the inverse frequencies directly on t's device
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels))
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        return torch.cat([pos_enc_a, pos_enc_b], dim=-1)

    def _sample_timesteps(self, batch_size: int, device: torch.device=None) -> "Int64[Tensor, 'batch']":
        """Uniformly sample one timestep in [1, fwd_diff.timesteps) per batch element.

        `device` is optional (defaults to CPU) to stay backward compatible
        with callers passing only `batch_size`.
        """
        return torch.randint(low=1, high=self.fwd_diff.timesteps, size=(batch_size,), device=device)

diffusion_models/models/unet.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ def __init__(
119119
self.dropout = dropout
120120
self.verbose = verbose
121121

122+
self.time_embedding_fc = nn.Linear(self.time_embedding_size, self.out_channels)
123+
122124
self.scale = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size=2, stride=2)
123125
self.conv1 = nn.Sequential(
124126
nn.Conv2d(self.out_channels * 2, self.out_channels, kernel_size=self.kernel_size, padding="same"),
@@ -156,16 +158,29 @@ def forward(
156158
"""
157159
if self.verbose:
158160
print(f"Decoder Input: {x.shape}\tSkip: {skip.shape}")
161+
159162
x = self.scale(x)
163+
160164
if self.verbose:
161165
print(f"After Scaling: {x.shape}")
166+
162167
x = torch.cat([x, skip], dim=1)
168+
163169
if self.verbose:
164170
print(f"After Concat {x.shape}")
171+
165172
x = self.conv1(x)
173+
174+
if time_embedding is not None:
175+
time_embedding = self.time_embedding_fc(time_embedding)
176+
time_embedding = time_embedding.view(time_embedding.shape[0], time_embedding.shape[1], 1, 1)
177+
x = x + time_embedding.expand(time_embedding.shape[0], time_embedding.shape[1], x.shape[-2], x.shape[-1])
178+
166179
if self.verbose:
167180
print(f"After Conv1: {x.shape}")
181+
168182
x = self.conv2(x)
183+
169184
if self.verbose:
170185
print(f"After Conv2: {x.shape}")
171186
return x
@@ -256,7 +271,7 @@ def forward(
256271
print("Encoding Channels", self.encoding_channels, "\tDecoding Channels", self.decoding_channels)
257272
if not self._check_sizes(x):
258273
raise ValueError("Choose appropriate image size.")
259-
274+
260275
# in_layer - to 64 channels
261276
x = self.in_conv(x)
262277

Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
1-
from typing import Callable, Optional
2-
from torchvision.datasets import MNIST
3-
from torchvision.transforms import Compose, ToTensor, Normalize
1+
from typing import Callable, Optional, Tuple
2+
from torchvision.datasets import MNIST, CIFAR10
3+
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
44
from typing import Any
55

6-
class MNISTTrainLoader(MNIST):
6+
class UnconditionedCifar10Dataset(CIFAR10):
    """CIFAR10 variant that always downloads and applies a fixed [-1, 1] normalization."""

    def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:
        # NOTE(review): caller-supplied `transform`/`download` are overridden
        # here — confirm this is intentional.
        channel_means = (0.5, 0.5, 0.5)
        channel_stds = (0.5, 0.5, 0.5)
        transform = Compose([ToTensor(), Normalize(channel_means, channel_stds)])
        download = True
        super().__init__(root, train, transform, target_transform, download)
11+
12+
class MNISTTrainDataset(MNIST):
    """MNIST training dataset normalized with standard MNIST statistics and resized to 32x32."""

    def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:
        # NOTE(review): caller-supplied `transform`/`download` are overridden.
        # Resize makes the 28x28 digits a power-of-two size (32x32).
        transform = Compose([
            ToTensor(),
            Normalize((0.1307,), (0.3081,)),
            Resize((32, 32)),
        ])
        download = True
        super().__init__(root, train, transform, target_transform, download)
1117

12-
class MNISTTestLoader(MNIST):
18+
class MNISTTestDataset(MNIST):
    """MNIST dataset with standard normalization; always downloads if missing.

    NOTE(review): `train` defaults to True despite the "Test" name — confirm
    callers pass train=False where intended.
    """

    def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:
        mnist_stats = ((0.1307,), (0.3081,))
        transform = Compose([ToTensor(), Normalize(*mnist_stats)])
        download = True
        super().__init__(root, train, transform, target_transform, download)

diffusion_models/utils/helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
class dotdict(dict):
    """dict subclass allowing dot.notation access to dictionary items.

    Missing keys accessed as attributes raise AttributeError (not KeyError),
    so protocols that probe attributes via getattr/hasattr — copy, pickle,
    ``getattr(obj, name, default)`` — keep working. (``dict.__getitem__``
    directly as ``__getattr__`` leaks KeyError, which those protocols do not
    catch.)
    """

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, i.e. for item access.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

diffusion_models/utils/trainer.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99
from time import time
1010
import wandb
11-
from typing import Callable, Literal, Any
11+
from typing import Callable, Literal, Any, Tuple
1212
import wandb
1313
from torch.nn import Module
1414

@@ -83,9 +83,9 @@ def _run_epoch(self, epoch):
8383
time1 = time()
8484
for data in self.train_data:
8585
if self.device_type == "cuda":
86-
data = map(lambda x: x.to(self.gpu_id), data)
86+
data = tuple(map(lambda x: x.to(self.gpu_id), data))
8787
else:
88-
data = map(lambda x: x.to(self.device_type), data)
88+
data = tuple(map(lambda x: x.to(self.device_type), data))
8989
batch_loss = self._run_batch(data)
9090
epoch_losses.append(batch_loss)
9191
if self.log_wandb:
@@ -118,6 +118,13 @@ def __init__(self, model: Module, train_data: Dataset, loss_func: Callable[...,
118118
super().__init__(model, train_data, loss_func, optimizer, gpu_id, batch_size, save_every, checkpoint_folder, device_type, log_wandb)
119119

120120
def _run_batch(self, data):
121+
"""Run a data batch.
122+
123+
Parameters
124+
----------
125+
data
126+
tuple of training batch and targets
127+
"""
121128
source, targets = data
122129
self.optimizer.zero_grad()
123130
pred = self.model(source)
@@ -131,5 +138,17 @@ def __init__(self, model: Module, train_data: Dataset, loss_func: Callable[...,
131138
super().__init__(model, train_data, loss_func, optimizer, gpu_id, batch_size, save_every, checkpoint_folder, device_type, log_wandb)
132139

133140
def _run_batch(self, data):
141+
"""Run a data batch.
142+
143+
Parameters
144+
----------
145+
data
146+
single item tuple of training batch
147+
"""
134148
self.optimizer.zero_grad()
135-
raise NotImplementedError("not finished yet")
149+
### to be changed!
150+
pred = self.model(data[0])
151+
loss = self.loss_func(*pred)
152+
loss.backward()
153+
self.optimizer.step()
154+
return loss.item()

tests/tests.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import context
from utils.datasets import UnconditionedCifar10Dataset
from torch.utils.data import DataLoader

# Smoke test: build the CIFAR10 dataset and pull a single batch through a DataLoader.
dataset = UnconditionedCifar10Dataset("./data")
loader = DataLoader(dataset, batch_size=10)

batch = next(iter(loader))
print(type(batch))

tests/train_discriminative.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch.multiprocessing as mp
1111
import os
1212
from utils.mp_setup import DDP_Proc_Group
13-
from utils.dataloaders import MNISTTrainLoader
13+
from utils.datasets import MNISTTrainDataset
1414
from utils.helpers import dotdict
1515
import wandb
1616
import torch.nn.functional as F
@@ -20,7 +20,7 @@
2020
batch_size = 1000,
2121
learning_rate = 0.001,
2222
device_type = "cpu",
23-
dataloader = MNISTTrainLoader,
23+
dataloader = MNISTTrainDataset,
2424
architecture = MNISTEncoder,
2525
out_classes = 10,
2626
optimizer = torch.optim.Adam,

tests/train_generative.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import context
2+
from torchvision.transforms import ToTensor, Compose, Normalize
3+
from torch.utils.data import DataLoader
4+
import torch
5+
import torch.nn as nn
6+
from models.mnist_enc import MNISTEncoder
7+
from models.unet import UNet
8+
from models.diffusion import DiffusionModel, ForwardDiffusion
9+
import numpy as np
10+
from time import time
11+
from utils.trainer import DiscriminativeTrainer, GenerativeTrainer
12+
import torch.multiprocessing as mp
13+
import os
14+
from utils.mp_setup import DDP_Proc_Group
15+
from utils.datasets import MNISTTrainDataset, UnconditionedCifar10Dataset
16+
from utils.helpers import dotdict
17+
import wandb
18+
import torch.nn.functional as F
19+
20+
# Run configuration for unconditional diffusion training on MNIST.
config = dotdict(
    total_epochs = 2,
    batch_size = 1000,
    learning_rate = 0.001,
    device_type = "cpu",
    dataset = MNISTTrainDataset,
    architecture = DiffusionModel,
    backbone = UNet,
    in_channels = 1,
    backbone_enc_depth = 4,
    kernel_size = 3,
    dropout = 0.5,
    forward_diff = ForwardDiffusion,
    max_timesteps = 1000,
    t_start = 0.0001,
    t_end = 0.02,
    schedule_type = "linear",
    time_enc_dim = 256,
    optimizer = torch.optim.Adam,
    data_path = os.path.abspath("./data"),
    checkpoint_folder = os.path.abspath(os.path.join("./data/checkpoints")),
    #data_path = "/itet-stor/peerli/net_scratch",
    #checkpoint_folder = "/itet-stor/peerli/net_scratch/mnist_checkpoints",
    save_every = 10,
    loss_func = F.mse_loss,
    log_wandb = False
)
# Removed leftover debug code that constructed an unused UNet/ForwardDiffusion/
# DiffusionModel at import time; load_train_objs() builds the real instances.
51+
52+
def load_train_objs(config):
    """Instantiate dataset, model and optimizer from the run configuration.

    Parameters
    ----------
    config
        dotdict carrying dataset/architecture/backbone classes and their
        hyperparameters

    Returns
    -------
    tuple of (train dataset, model, optimizer)
    """
    train_set = config.dataset(config.data_path)
    unet = config.backbone(
        num_encoding_blocks = config.backbone_enc_depth,
        in_channels = config.in_channels,
        kernel_size = config.kernel_size,
        dropout = config.dropout,
        time_emb_size = config.time_enc_dim
    )
    forward_process = config.forward_diff(
        config.max_timesteps,
        config.t_start,
        config.t_end,
        config.schedule_type
    )
    model = config.architecture(unet, forward_process, config.time_enc_dim)
    optimizer = config.optimizer(model.parameters(), lr=config.learning_rate)
    return train_set, model, optimizer
72+
73+
def training(rank, world_size, config):
    """Entry point for one training process/rank.

    Parameters
    ----------
    rank
        process rank (GPU id under DDP, 0 for single-process runs)
    world_size
        total number of processes
    config
        run configuration (see module-level `config`)
    """
    # only rank 0 logs to Weights & Biases
    if rank == 0 and config.log_wandb:
        wandb.init(project="mnist_trials", config=config, save_code=True)
    train_set, model, optimizer = load_train_objs(config)
    trainer = GenerativeTrainer(
        model,
        train_set,
        config.loss_func,
        optimizer,
        rank,
        config.batch_size,
        config.save_every,
        config.checkpoint_folder,
        config.device_type,
        config.log_wandb
    )
    trainer.train(config.total_epochs)
90+
91+
if __name__ == "__main__":
    # CPU path runs a single process; CUDA path spawns one process per GPU.
    if config.device_type != "cuda":
        training(0, 0, config)
    else:
        world_size = torch.cuda.device_count()
        print("Device Count:", world_size)
        mp.spawn(DDP_Proc_Group(training), args=(world_size, config), nprocs=world_size)

0 commit comments

Comments
 (0)