Skip to content

Commit

Permalink
updates to forward diff
Browse files Browse the repository at this point in the history
  • Loading branch information
Lionel Peer committed Sep 2, 2023
1 parent 4057664 commit 16cc8f4
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 46 deletions.
86 changes: 63 additions & 23 deletions diffusion_models/models/forward_diffusion.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,80 @@
import torch
from torch import nn
from torch import nn, Tensor
from jaxtyping import Float
from typing import Literal

class ForwardDiffusion(nn.Module):
def __init__(self, timesteps: int, start: float=0.0001, end: float=0.02, random_seed: int=42, type="linear", device=None) -> None:
"""Class for forward diffusion process in DDPMs (denoising diffusion probabilistic models).
Attributes
----------
timesteps
max number of supported timesteps of the schedule
start
start value of scheduler
end
end value of scheduler
type
type of scheduler, currently linear and cosine supported
"""
def __init__(self, timesteps: int, start: float=0.0001, end: float=0.02, type: Literal["linear", "cosine"]="linear") -> None:
"""Constructor of ForwardDiffusion class.
Parameters
----------
timesteps
timesteps
start
start
end
end
type
type
"""
super().__init__()
self.timesteps = timesteps
self.start = start
self.end = end
self.random_seed = random_seed
torch.manual_seed(self.random_seed)
self.type = type
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device
if self.type == "linear":
self.betas = self.linear_schedule(timesteps=self.timesteps, start=self.start, end=self.end)
self.betas = self._linear_scheduler(timesteps=self.timesteps, start=self.start, end=self.end)
elif self.type == "cosine":
raise NotImplementedError("Cosine scheduler not implemented yet.")
self.betas = self._cosine_scheduler(timesteps=self.timesteps, start=self.start, end=self.end)
else:
raise NotImplementedError("Invalid scheduler option:", type)
self.alphas = 1. - self.betas
self.alphas_dash = torch.cumprod(self.alphas, axis=0)
self.sqrt_alphas_dash = torch.sqrt(self.alphas_dash)
self.sqrt_one_minus_alpha_dash = torch.sqrt(1. - self.alphas_dash)

def forward(self, x_0: torch.Tensor, t: int):
noise_normal = torch.randn_like(x_0).to(self.device)
sqrt_alpha_dash_t = self.sqrt_alphas_dash[t].to(self.device)
sqrt_one_minus_alpha_dash_t = self.sqrt_one_minus_alpha_dash[t].to(self.device)
if x_0.device != self.device:
x_0 = x_0.to(self.device)
x_t = sqrt_alpha_dash_t * x_0 + sqrt_one_minus_alpha_dash_t * noise_normal
self.register_buffer("alphas_dash", torch.cumprod(self.alphas, axis=0), persistent=False)
self.register_buffer("sqrt_alphas_dash", torch.sqrt(self.alphas_dash), persistent=False)
self.register_buffer("sqrt_one_minus_alpha_dash", 1. - self.alphas_dash, persistent=False)

self.register_buffer("noise_normal", torch.empty((1)), persistent=False)

def forward(self, x_0: Float[Tensor, "batch channels height width"], t: int) -> Float[Tensor, "batch channels height width"]:
"""Forward method of ForwardDiffusion class.
Parameters
----------
x_0
input tensor where noise should be applied to
t
timestep of the noise scheduler from which noise should be chosen
Returns
-------
Float[Tensor, "batch channels height width"]
tensor with applied noise according to schedule and chosen timestep
"""
self.noise_normal = torch.randn_like(x_0)
if t > self.timesteps-1:
raise IndexError("t ({}) chosen larger than max. available t ({})".format(t, self.timesteps-1))
sqrt_alpha_dash_t = self.sqrt_alphas_dash[t]
sqrt_one_minus_alpha_dash_t = self.sqrt_one_minus_alpha_dash[t]
x_t = sqrt_alpha_dash_t * x_0 + sqrt_one_minus_alpha_dash_t * self.noise_normal
return x_t

def linear_schedule(self, timesteps, start, end):
def _linear_scheduler(self, timesteps, start, end):
return torch.linspace(start, end, timesteps)

def cosine_scheduler(self, timesteps, start, end):
pass
def _cosine_scheduler(self, timesteps, start, end):
raise NotImplementedError("Cosine scheduler not implemented yet.")
8 changes: 4 additions & 4 deletions docs/source/fig/drawing.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
107 changes: 88 additions & 19 deletions examples/forward_process.ipynb

Large diffs are not rendered by default.

0 comments on commit 16cc8f4

Please sign in to comment.