Skip to content

Commit

Permalink
ready for training
Browse files Browse the repository at this point in the history
  • Loading branch information
liopeer committed Sep 22, 2023
1 parent 0319f0b commit 2d9c488
Show file tree
Hide file tree
Showing 9 changed files with 221 additions and 72 deletions.
98 changes: 60 additions & 38 deletions diffusion_models/models/diffusion.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,8 @@
import torch
from torch import nn, Tensor
from jaxtyping import Float, Int64
from jaxtyping import Float, Int64, Int
from typing import Literal

class DiffusionModel(nn.Module):
def __init__(
self,
backbone: nn.Module,
timesteps: int,
t_start: float=0.0001,
t_end: float=0.02,
schedule_type: Literal["linear", "cosine"]="linear"
) -> None:
super().__init__()
self.model = backbone
self.fwd_diff = ForwardDiffusion(timesteps, t_start, t_end, schedule_type)

def forward(self, x):
t = self._sample_timestep(x.shape[0])
t = t.unsqueeze(-1).type(torch.float)
t = self._pos_encoding(t, self.time_dim)
x_t, noise = self.fwd_diff(x, t)
noise_pred = self.model(x_t, t)
return noise_pred, noise

def _pos_encoding(self, t, channels):
inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=self.device).float() / channels))
pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
return pos_enc

def _sample_timestep(self, batch_size: int) -> Int64[Tensor, "batch"]:
return torch.randint(low=1, high=self.fwd_diff.noise_steps, size=(batch_size,))


class ForwardDiffusion(nn.Module):
"""Class for forward diffusion process in DDPMs (denoising diffusion probabilistic models).
Expand Down Expand Up @@ -82,7 +50,11 @@ def __init__(self, timesteps: int, start: float=0.0001, end: float=0.02, type: L

self.register_buffer("noise_normal", torch.empty((1)), persistent=False)

def forward(self, x_0: Float[Tensor, "batch channels height width"], t: int) -> Float[Tensor, "batch channels height width"]:
def forward(
self,
x_0: Float[Tensor, "batch channels height width"],
t: Int[Tensor, "batch"]
) -> Float[Tensor, "batch channels height width"]:
"""Forward method of ForwardDiffusion class.
Parameters
Expand All @@ -98,15 +70,65 @@ def forward(self, x_0: Float[Tensor, "batch channels height width"], t: int) ->
tensor with applied noise according to schedule and chosen timestep
"""
self.noise_normal = torch.randn_like(x_0)
if t > self.timesteps-1:
if True in torch.gt(t, self.timesteps-1):
raise IndexError("t ({}) chosen larger than max. available t ({})".format(t, self.timesteps-1))
sqrt_alpha_dash_t = self.sqrt_alphas_dash[t]
sqrt_one_minus_alpha_dash_t = self.sqrt_one_minus_alpha_dash[t]
x_t = sqrt_alpha_dash_t * x_0 + sqrt_one_minus_alpha_dash_t * self.noise_normal
return x_t
x_t = sqrt_alpha_dash_t.view(-1, 1, 1, 1) * x_0
x_t += sqrt_one_minus_alpha_dash_t.view(-1, 1, 1, 1) * self.noise_normal
return x_t, self.noise_normal

def _linear_scheduler(self, timesteps, start, end):
return torch.linspace(start, end, timesteps)

def _cosine_scheduler(self, timesteps, start, end):
raise NotImplementedError("Cosine scheduler not implemented yet.")
raise NotImplementedError("Cosine scheduler not implemented yet.")

class DiffusionModel(nn.Module):
    """Combine a backbone with a forward diffusion process for noise-prediction training.

    On every forward pass a random timestep is drawn per batch element, the
    input is noised by ``fwd_diff`` at those timesteps, and the backbone is
    asked to predict the noise from the noisy input and a sinusoidal encoding
    of the timesteps.

    Parameters
    ----------
    backbone
        module called as ``backbone(x_t, time_enc)``
    fwd_diff
        module called as ``fwd_diff(x, timesteps)`` returning ``(x_t, noise)``;
        its ``timesteps`` attribute is used as the exclusive upper bound when sampling
    time_enc_dim
        width of the time encoding handed to the backbone
    """

    def __init__(
        self,
        backbone: nn.Module,
        fwd_diff: ForwardDiffusion,
        time_enc_dim: int=256
    ) -> None:
        super().__init__()
        self.model = backbone
        self.fwd_diff = fwd_diff
        self.time_enc_dim = time_enc_dim

        # non-persistent placeholders; both are overwritten on every forward pass
        self.register_buffer("timesteps", torch.empty((1)), persistent=False)
        self.register_buffer("time_enc", torch.empty((1)), persistent=False)

    def forward(self, x):
        """Noise a batch at random timesteps and predict the applied noise.

        Parameters
        ----------
        x
            input batch; its first dimension is the batch size

        Returns
        -------
        tuple
            ``(noise_pred, noise)``: backbone output and the noise returned by ``fwd_diff``
        """
        # sample the timesteps on the same device as the input, otherwise the
        # time encoding passed to the backbone ends up on the CPU while x_t is on the GPU
        self.timesteps = self._sample_timesteps(x.shape[0], device=x.device)

        # convert timesteps into time encodings
        self.time_enc = self._time_encoding(self.timesteps, self.time_enc_dim)

        # create batch of noisy images
        x_t, noise = self.fwd_diff(x, self.timesteps)

        # run noisy images, conditioned on time through model
        noise_pred = self.model(x_t, self.time_enc)
        return noise_pred, noise

    def sample(self, n):
        """Sample a batch of images.

        Not implemented yet: the method currently does nothing and returns None.
        """
        pass

    def _time_encoding(
        self,
        t: Int[Tensor, "batch"],
        channels: int
    ) -> Float[Tensor, "batch time_enc_dim"]:
        """Encode integer timesteps as concatenated sine and cosine features.

        Parameters
        ----------
        t
            1-D tensor of timesteps
        channels
            width of the encoding; must be even, otherwise the frequency vector
            and the repeated timesteps have mismatching lengths

        Returns
        -------
        Tensor
            ``channels // 2`` sine features followed by ``channels // 2`` cosine
            features per timestep, on the device of ``t``
        """
        t = t.unsqueeze(-1).type(torch.float)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        inv_freq = inv_freq.to(t.device)
        scaled = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)

    def _sample_timesteps(self, batch_size: int, device=None) -> Int64[Tensor, "batch"]:
        """Draw one random timestep per batch element.

        Parameters
        ----------
        batch_size
            number of timesteps to draw
        device
            device on which the tensor is created; ``None`` keeps the torch default

        Returns
        -------
        Tensor
            int64 tensor of shape ``(batch_size,)`` with values in ``[1, fwd_diff.timesteps - 1]``
        """
        # NOTE(review): low=1 means schedule index 0 is never drawn - confirm this is intended
        return torch.randint(
            low=1,
            high=self.fwd_diff.timesteps,
            size=(batch_size,),
            device=device,
        )
17 changes: 16 additions & 1 deletion diffusion_models/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ def __init__(
self.dropout = dropout
self.verbose = verbose

self.time_embedding_fc = nn.Linear(self.time_embedding_size, self.out_channels)

self.scale = nn.ConvTranspose2d(self.in_channels, self.out_channels, kernel_size=2, stride=2)
self.conv1 = nn.Sequential(
nn.Conv2d(self.out_channels * 2, self.out_channels, kernel_size=self.kernel_size, padding="same"),
Expand Down Expand Up @@ -156,16 +158,29 @@ def forward(
"""
if self.verbose:
print(f"Decoder Input: {x.shape}\tSkip: {skip.shape}")

x = self.scale(x)

if self.verbose:
print(f"After Scaling: {x.shape}")

x = torch.cat([x, skip], dim=1)

if self.verbose:
print(f"After Concat {x.shape}")

x = self.conv1(x)

if time_embedding is not None:
time_embedding = self.time_embedding_fc(time_embedding)
time_embedding = time_embedding.view(time_embedding.shape[0], time_embedding.shape[1], 1, 1)
x = x + time_embedding.expand(time_embedding.shape[0], time_embedding.shape[1], x.shape[-2], x.shape[-1])

if self.verbose:
print(f"After Conv1: {x.shape}")

x = self.conv2(x)

if self.verbose:
print(f"After Conv2: {x.shape}")
return x
Expand Down Expand Up @@ -256,7 +271,7 @@ def forward(
print("Encoding Channels", self.encoding_channels, "\tDecoding Channels", self.decoding_channels)
if not self._check_sizes(x):
raise ValueError("Choose appropriate image size.")

# in_layer - to 64 channels
x = self.in_conv(x)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
from typing import Callable, Optional
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize
from typing import Callable, Optional, Tuple
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize, Resize
from typing import Any

class MNISTTrainLoader(MNIST):
class UnconditionedCifar10Dataset(CIFAR10):
def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:
transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
transform = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
download = True
super().__init__(root, train, transform, target_transform, download)

class MNISTTrainDataset(MNIST):
    """MNIST dataset with a fixed preprocessing pipeline.

    Images are converted to tensors, normalised with mean 0.1307 and std 0.3081
    and resized to 32x32. The ``transform`` and ``download`` arguments are
    accepted but overridden: the fixed pipeline is always used and downloading
    is always enabled.
    """

    def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:
        preprocessing = Compose(
            [
                ToTensor(),
                Normalize((0.1307,), (0.3081,)),
                Resize((32,32)),
            ]
        )
        # positional order of the parent: root, train, transform, target_transform, download
        super().__init__(root, train, preprocessing, target_transform, True)

class MNISTTestLoader(MNIST):
class MNISTTestDataset(MNIST):
def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:
transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
download = False
download = True
super().__init__(root, train, transform, target_transform, download)
2 changes: 1 addition & 1 deletion diffusion_models/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
class dotdict(dict):
    """dot.notation access to dictionary attributes

    Keys can be read, written and deleted as attributes. Reading a key that
    does not exist raises AttributeError, so that ``hasattr``, ``getattr`` with
    a default, and the attribute probing done by copying and pickling work.
    """

    def __getattr__(self, name):
        # Only reached when regular attribute lookup fails, so dict methods
        # such as ``items`` are never shadowed by keys.
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from err

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
27 changes: 23 additions & 4 deletions diffusion_models/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import numpy as np
from time import time
import wandb
from typing import Callable, Literal, Any
from typing import Callable, Literal, Any, Tuple
import wandb
from torch.nn import Module

Expand Down Expand Up @@ -83,9 +83,9 @@ def _run_epoch(self, epoch):
time1 = time()
for data in self.train_data:
if self.device_type == "cuda":
data = map(lambda x: x.to(self.gpu_id), data)
data = tuple(map(lambda x: x.to(self.gpu_id), data))
else:
data = map(lambda x: x.to(self.device_type), data)
data = tuple(map(lambda x: x.to(self.device_type), data))
batch_loss = self._run_batch(data)
epoch_losses.append(batch_loss)
if self.log_wandb:
Expand Down Expand Up @@ -118,6 +118,13 @@ def __init__(self, model: Module, train_data: Dataset, loss_func: Callable[...,
super().__init__(model, train_data, loss_func, optimizer, gpu_id, batch_size, save_every, checkpoint_folder, device_type, log_wandb)

def _run_batch(self, data):
"""Run a data batch.
Parameters
----------
data
tuple of training batch and targets
"""
source, targets = data
self.optimizer.zero_grad()
pred = self.model(source)
Expand All @@ -131,5 +138,17 @@ def __init__(self, model: Module, train_data: Dataset, loss_func: Callable[...,
super().__init__(model, train_data, loss_func, optimizer, gpu_id, batch_size, save_every, checkpoint_folder, device_type, log_wandb)

def _run_batch(self, data):
"""Run a data batch.
Parameters
----------
data
single item tuple of training batch
"""
self.optimizer.zero_grad()
raise NotImplementedError("not finished yet")
### to be changed!
pred = self.model(data[0])
loss = self.loss_func(*pred)
loss.backward()
self.optimizer.step()
return loss.item()
9 changes: 9 additions & 0 deletions tests/tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import context
from utils.datasets import UnconditionedCifar10Dataset
from torch.utils.data import DataLoader

# Smoke check: build the CIFAR-10 dataset rooted at ./data and wrap it in a loader.
ds = UnconditionedCifar10Dataset("./data")
dl = DataLoader(ds, batch_size=10)

# Fetch the first batch and print the type of the container the loader yields.
k = next(iter(dl))
print(type(k))
4 changes: 2 additions & 2 deletions tests/train_discriminative.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch.multiprocessing as mp
import os
from utils.mp_setup import DDP_Proc_Group
from utils.dataloaders import MNISTTrainLoader
from utils.datasets import MNISTTrainDataset
from utils.helpers import dotdict
import wandb
import torch.nn.functional as F
Expand All @@ -20,7 +20,7 @@
batch_size = 1000,
learning_rate = 0.001,
device_type = "cpu",
dataloader = MNISTTrainLoader,
dataloader = MNISTTrainDataset,
architecture = MNISTEncoder,
out_classes = 10,
optimizer = torch.optim.Adam,
Expand Down
97 changes: 97 additions & 0 deletions tests/train_generative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import context
from torchvision.transforms import ToTensor, Compose, Normalize
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from models.mnist_enc import MNISTEncoder
from models.unet import UNet
from models.diffusion import DiffusionModel, ForwardDiffusion
import numpy as np
from time import time
from utils.trainer import DiscriminativeTrainer, GenerativeTrainer
import torch.multiprocessing as mp
import os
from utils.mp_setup import DDP_Proc_Group
from utils.datasets import MNISTTrainDataset, UnconditionedCifar10Dataset
from utils.helpers import dotdict
import wandb
import torch.nn.functional as F

# Run configuration; read by load_train_objs, training and the __main__ guard below.
config = dotdict(
    total_epochs = 2,
    batch_size = 1000,
    learning_rate = 0.001,
    device_type = "cpu",
    dataset = MNISTTrainDataset,
    # model: architecture wraps a backbone and a forward diffusion process
    architecture = DiffusionModel,
    backbone = UNet,
    in_channels = 1,
    backbone_enc_depth = 4,
    kernel_size = 3,
    dropout = 0.5,
    # forward diffusion: positional arguments passed to forward_diff in load_train_objs
    forward_diff = ForwardDiffusion,
    max_timesteps = 1000,
    t_start = 0.0001,
    t_end = 0.02,
    schedule_type = "linear",
    # passed both to the backbone (time_emb_size) and to the architecture
    time_enc_dim = 256,
    optimizer = torch.optim.Adam,
    data_path = os.path.abspath("./data"),
    checkpoint_folder = os.path.abspath(os.path.join("./data/checkpoints")),
    #data_path = "/itet-stor/peerli/net_scratch",
    #checkpoint_folder = "/itet-stor/peerli/net_scratch/mnist_checkpoints",
    save_every = 10,
    loss_func = F.mse_loss,
    log_wandb = False
)

# NOTE(review): these module-level objects are built on import but are not used by
# load_train_objs or training, which construct their own instances from config -
# presumably leftover from experimentation; confirm before removing.
backbone = UNet(4)
fwd_diff = ForwardDiffusion(timesteps=1000)
model = DiffusionModel(backbone, fwd_diff)

def load_train_objs(config):
    """Instantiate dataset, model and optimizer from the run configuration.

    Parameters
    ----------
    config
        object exposing the configuration entries as attributes

    Returns
    -------
    tuple
        ``(train_set, model, optimizer)``
    """
    train_set = config.dataset(config.data_path)

    build_model = config.architecture
    # backbone is configured by keyword, the forward diffusion process by position
    network = config.backbone(
        num_encoding_blocks=config.backbone_enc_depth,
        in_channels=config.in_channels,
        kernel_size=config.kernel_size,
        dropout=config.dropout,
        time_emb_size=config.time_enc_dim,
    )
    forward_process = config.forward_diff(
        config.max_timesteps,
        config.t_start,
        config.t_end,
        config.schedule_type,
    )
    model = build_model(network, forward_process, config.time_enc_dim)

    optimizer = config.optimizer(model.parameters(), lr=config.learning_rate)
    return train_set, model, optimizer

def training(rank, world_size, config):
    """Set up a generative trainer for one process and run the training loop.

    Parameters
    ----------
    rank
        index of this process; passed to the trainer and used to decide which process initialises wandb
    world_size
        number of processes (not used inside this function)
    config
        object exposing the configuration entries as attributes
    """
    is_first_process = rank == 0
    if is_first_process and config.log_wandb:
        wandb.init(project="mnist_trials", config=config, save_code=True)

    dataset, model, optimizer = load_train_objs(config)

    # the trainer takes its arguments by position, keep this order
    trainer_args = (
        model,
        dataset,
        config.loss_func,
        optimizer,
        rank,
        config.batch_size,
        config.save_every,
        config.checkpoint_folder,
        config.device_type,
        config.log_wandb,
    )
    trainer = GenerativeTrainer(*trainer_args)
    trainer.train(config.total_epochs)

if __name__ == "__main__":
    if config.device_type != "cuda":
        # single-process run: rank 0, world size 0
        training(0, 0, config)
    else:
        # one process per visible GPU
        world_size = torch.cuda.device_count()
        print("Device Count:", world_size)
        mp.spawn(DDP_Proc_Group(training), args=(world_size, config), nprocs=world_size)
Loading

0 comments on commit 2d9c488

Please sign in to comment.