Commit 16cc8f4
Lionel Peer committed on Sep 2, 2023 · 1 parent: 4057664
Showing 3 changed files with 155 additions and 46 deletions.
```diff
@@ -1,40 +1,80 @@
 import torch
-from torch import nn
+from torch import nn, Tensor
+from jaxtyping import Float
+from typing import Literal
 
 
 class ForwardDiffusion(nn.Module):
-    def __init__(self, timesteps: int, start: float=0.0001, end: float=0.02, random_seed: int=42, type="linear", device=None) -> None:
+    """Class for the forward diffusion process in DDPMs (denoising diffusion probabilistic models).
+
+    Attributes
+    ----------
+    timesteps
+        max number of supported timesteps of the schedule
+    start
+        start value of the scheduler
+    end
+        end value of the scheduler
+    type
+        type of scheduler, currently "linear" and "cosine" supported
+    """
+    def __init__(self, timesteps: int, start: float=0.0001, end: float=0.02, type: Literal["linear", "cosine"]="linear") -> None:
"""Constructor of ForwardDiffusion class. | ||
Parameters | ||
---------- | ||
timesteps | ||
timesteps | ||
start | ||
start | ||
end | ||
end | ||
type | ||
type | ||
""" | ||
         super().__init__()
         self.timesteps = timesteps
         self.start = start
         self.end = end
-        self.random_seed = random_seed
-        torch.manual_seed(self.random_seed)
         self.type = type
-        if device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = device
         if self.type == "linear":
-            self.betas = self.linear_schedule(timesteps=self.timesteps, start=self.start, end=self.end)
+            self.betas = self._linear_scheduler(timesteps=self.timesteps, start=self.start, end=self.end)
         elif self.type == "cosine":
-            raise NotImplementedError("Cosine scheduler not implemented yet.")
+            self.betas = self._cosine_scheduler(timesteps=self.timesteps, start=self.start, end=self.end)
         else:
             raise NotImplementedError("Invalid scheduler option:", type)
         self.alphas = 1. - self.betas
-        self.alphas_dash = torch.cumprod(self.alphas, axis=0)
-        self.sqrt_alphas_dash = torch.sqrt(self.alphas_dash)
-        self.sqrt_one_minus_alpha_dash = torch.sqrt(1. - self.alphas_dash)
-
-    def forward(self, x_0: torch.Tensor, t: int):
-        noise_normal = torch.randn_like(x_0).to(self.device)
-        sqrt_alpha_dash_t = self.sqrt_alphas_dash[t].to(self.device)
-        sqrt_one_minus_alpha_dash_t = self.sqrt_one_minus_alpha_dash[t].to(self.device)
-        if x_0.device != self.device:
-            x_0 = x_0.to(self.device)
-        x_t = sqrt_alpha_dash_t * x_0 + sqrt_one_minus_alpha_dash_t * noise_normal
+        self.register_buffer("alphas_dash", torch.cumprod(self.alphas, axis=0), persistent=False)
+        self.register_buffer("sqrt_alphas_dash", torch.sqrt(self.alphas_dash), persistent=False)
self.register_buffer("sqrt_one_minus_alpha_dash", 1. - self.alphas_dash, persistent=False) | ||
+
+        self.register_buffer("noise_normal", torch.empty((1)), persistent=False)
+
+    def forward(self, x_0: Float[Tensor, "batch channels height width"], t: int) -> Float[Tensor, "batch channels height width"]:
+        """Forward method of the ForwardDiffusion class.
+
+        Parameters
+        ----------
+        x_0
+            input tensor to which noise should be applied
+        t
+            timestep of the noise scheduler from which noise should be chosen
+
+        Returns
+        -------
+        Float[Tensor, "batch channels height width"]
+            tensor with noise applied according to the schedule at the chosen timestep
+        """
+        self.noise_normal = torch.randn_like(x_0)
+        if t > self.timesteps - 1:
+            raise IndexError("t ({}) chosen larger than max. available t ({})".format(t, self.timesteps - 1))
+        sqrt_alpha_dash_t = self.sqrt_alphas_dash[t]
+        sqrt_one_minus_alpha_dash_t = self.sqrt_one_minus_alpha_dash[t]
+        x_t = sqrt_alpha_dash_t * x_0 + sqrt_one_minus_alpha_dash_t * self.noise_normal
         return x_t
 
-    def linear_schedule(self, timesteps, start, end):
+    def _linear_scheduler(self, timesteps, start, end):
         return torch.linspace(start, end, timesteps)
 
-    def cosine_scheduler(self, timesteps, start, end):
-        pass
+    def _cosine_scheduler(self, timesteps, start, end):
+        raise NotImplementedError("Cosine scheduler not implemented yet.")
```
The remaining two changed files are not rendered (one file is invalid; large diffs are not rendered by default).