added cosine scheduler
liopeer committed Oct 8, 2023
1 parent f3cc9aa commit ab1dca8
Showing 11 changed files with 713 additions and 95 deletions.
35 changes: 14 additions & 21 deletions diffusion_models/models/diffusion.py
@@ -19,25 +19,31 @@ class ForwardDiffusion(nn.Module):
     type
         type of scheduler, currently linear and cosine supported
     """
-    def __init__(self, timesteps: int, start: float=0.0001, end: float=0.02, type: Literal["linear", "cosine"]="linear") -> None:
+    def __init__(self, timesteps: int, start: float=0.0001, end: float=0.02, offset: float=0.008, max_beta: float=0.999, type: Literal["linear", "cosine"]="linear") -> None:
         """Constructor of ForwardDiffusion class.
         Parameters
         ----------
         timesteps
             total number of timesteps in diffusion process
         start
             start beta for linear scheduler
         end
             end beta for linear scheduler
+        offset
+            offset parameter for cosine scheduler
+        max_beta
+            maximal value to clip betas for cosine scheduler
         type
             type of scheduler, either linear or cosine
         """
         super().__init__()
         self.timesteps = timesteps
         self.start = start
         self.end = end
+        self.offset = offset
         self.type = type
+        self.max_beta = max_beta
         if self.type == "linear":
             self.init_betas = self._linear_scheduler(timesteps=self.timesteps, start=self.start, end=self.end)
         elif self.type == "cosine":
@@ -84,26 +90,13 @@ def _linear_scheduler(self, timesteps, start, end):
         return torch.linspace(start, end, timesteps)

     def _cosine_scheduler(self, timesteps, start, end):
-        return self._betas_for_alpha_bar(
-            timesteps,
-            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
-        )
+        """t is actually t/T from the paper"""
+        return self._betas_for_alpha_bar(timesteps, lambda t: math.cos((t + self.offset) / (1.0 + self.offset) * math.pi / 2) ** 2, self.max_beta)

     def _betas_for_alpha_bar(self, num_diffusion_timesteps, alpha_bar, max_beta=0.999):
-        """
-        Create a beta schedule that discretizes the given alpha_t_bar function,
-        which defines the cumulative product of (1-beta) over time from t = [0,1].
-        :param num_diffusion_timesteps: the number of betas to produce.
-        :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
-            produces the cumulative product of (1-beta) up to that
-            part of the diffusion process.
-        :param max_beta: the maximum beta to use; use values lower than 1 to
-            prevent singularities.
-        """
         betas = []
         for i in range(num_diffusion_timesteps):
-            t1 = i / num_diffusion_timesteps
+            t1 = i / num_diffusion_timesteps  # t -> t/T
             t2 = (i + 1) / num_diffusion_timesteps
             betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
         return torch.tensor(betas)
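The new cosine scheduler follows Nichol & Dhariwal's "Improved Denoising Diffusion Probabilistic Models": alpha_bar(t) = cos^2(((t/T + s)/(1 + s)) * pi/2) with a small offset s, and beta_t = 1 - alpha_bar(t+1)/alpha_bar(t), clipped at max_beta to avoid a singularity near t = T. A minimal standalone sketch of the same computation (the function name and the final driver line are illustrative, not part of the commit):

import math
import torch

def cosine_betas(timesteps: int, offset: float = 0.008, max_beta: float = 0.999) -> torch.Tensor:
    # alpha_bar over rescaled time t in [0, 1]; the offset keeps beta_1 from being vanishingly small
    alpha_bar = lambda t: math.cos((t + offset) / (1.0 + offset) * math.pi / 2) ** 2
    betas = []
    for i in range(timesteps):
        t1, t2 = i / timesteps, (i + 1) / timesteps  # consecutive steps, rescaled to [0, 1]
        # beta_t = 1 - alpha_bar(t+1) / alpha_bar(t), clipped so the alphas never hit zero
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas)

betas = cosine_betas(1000)  # should reproduce ForwardDiffusion(timesteps=1000, type="cosine").init_betas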
2 changes: 1 addition & 1 deletion diffusion_models/utils/datasets.py
@@ -22,7 +22,7 @@ def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:

 class MNISTTestDataset(MNIST):
     def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:
-        transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
+        transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,)), Resize((32,32), antialias=True)])
         download = True
         super().__init__(root, train, transform, target_transform, download)

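Presumably the added Resize((32,32), antialias=True) brings MNIST's native 28x28 digits up to a power-of-two resolution so repeated 2x down- and upsampling in a UNet divides evenly. A quick shape check under that assumption (the root path and download flag are illustrative):

from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,)), Resize((32, 32), antialias=True)])
ds = MNIST(root="data", train=False, transform=transform, download=True)
x, _ = ds[0]
print(x.shape)  # torch.Size([1, 32, 32]), which halves cleanly: 32 -> 16 -> 8 -> 4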
6 changes: 6 additions & 0 deletions diffusion_models/utils/trainer.py
@@ -117,6 +117,12 @@ def train(self, max_epochs: int):
             if (self.gpu_id == 0) and (epoch % self.save_every == 0) and (epoch != 0):
                 self._save_checkpoint(epoch)

+    def load_checkpoint(self, checkpoint_path: str):
+        if self.device_type == "cuda":
+            self.model.module.load_state_dict(torch.load(checkpoint_path))
+        else:
+            self.model.load_state_dict(torch.load(checkpoint_path))
+
 class DiscriminativeTrainer(Trainer):
     def __init__(self, model: Module, train_data: Dataset, loss_func: Callable[..., Any], optimizer: Optimizer, gpu_id: int, batch_size: int, save_every: int, checkpoint_folder: str, device_type: Literal['cuda', 'mps', 'cpu'], log_wandb: bool = True) -> None:
         super().__init__(model, train_data, loss_func, optimizer, gpu_id, batch_size, save_every, checkpoint_folder, device_type, log_wandb)
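The cuda branch unwraps self.model.module because DistributedDataParallel wraps the real network, and its parameter names carry a "module." prefix that would mismatch keys saved from the bare module. A sketch of the same pattern in isolation, assuming checkpoints hold the bare module's state dict; the map_location="cpu" safeguard is an addition here, not part of the commit:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def load_into(model: torch.nn.Module, checkpoint_path: str) -> None:
    # When the model is DDP-wrapped, load into the unwrapped module,
    # since that is where the checkpointed parameter names live
    target = model.module if isinstance(model, DDP) else model
    target.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))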
15 changes: 8 additions & 7 deletions examples/forward_process.py
@@ -11,17 +11,18 @@
 import torch.nn as nn
 import context
 from models.diffusion import ForwardDiffusion
+from torchvision.transforms import Compose, Normalize, Resize
 import torchvision

 mode = "cosine"
-timesteps = 1000
-every = 100
+timesteps = 1200
+every = 200

 #############################################################################

 img = "/Users/lionelpeer/Pictures/2020/Japan/darktable_exported/DSC_1808.jpg"
 img = read_image(img) / 255
-transform = torchvision.transforms.Resize((120, 180))
+transform = Compose([Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), Resize((120, 180))])
 img = transform(img)

 print("image shape:", img.shape)
@@ -37,17 +38,17 @@
 noiser = ForwardDiffusion(timesteps=timesteps, type=mode).to(device)
 batch = batch.to(device)

-noisies, noise = noiser.forward(batch[0], torch.tensor([i*every-1 for i in range(1,timesteps//every)]))
+noisies, noise = noiser.forward(batch[0], torch.tensor([i*every for i in range(0,timesteps//every)]))
 noisies = [noisies[i].permute(1,2,0) for i in range(noisies.shape[0])]

 blub = [r"$x_{} \sim q(x_{})$".format("{"+str(0)+"}", "{"+str(0)+"}")]
 titles = [
-    r"$x_{} \sim q(x_{}\mid x_{})$".format("{"+str(i)+"}", "{"+str(i)+"}", "{"+str(i-1)+"}") for i in [j*every for j in range(1, timesteps//every)]
+    r"$x_{} \sim q(x_{}\mid x_{})$".format("{"+str(i)+"}", "{"+str(i)+"}", "{"+str(i-1)+"}") for i in [j*every for j in range(1, timesteps//every+1)]
 ]
 blub.extend(titles)

-fig, ax = plt.subplots(1,timesteps//every-1,figsize=(25,5))
-for i, (elem, title) in enumerate(zip(noisies[:-1], blub[:-1])):
+fig, ax = plt.subplots(1,timesteps//every,figsize=(25,5))
+for i, (elem, title) in enumerate(zip(noisies[:], blub[:])):
     ax[i].imshow(elem.cpu())
     ax[i].axis("off")
     ax[i].set_title(title)
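The re-indexing changes which timesteps are plotted: the old formula skipped x_0 and the subplot grid dropped the final panel, while the new one starts at the clean image and uses every sampled step. A quick check with the new constants (variable names here are illustrative):

import torch

timesteps, every = 1200, 200
old_idx = torch.tensor([i*every - 1 for i in range(1, timesteps//every)])  # tensor([199, 399, 599, 799, 999])
new_idx = torch.tensor([i*every for i in range(0, timesteps//every)])      # tensor([0, 200, 400, 600, 800, 1000])
# 6 panels now, including x_0 at t=0, hence subplots(1, timesteps//every)
# instead of subplots(1, timesteps//every - 1)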
Binary file modified examples/img/forward_naoshima_cosine.png
Binary file modified examples/img/forward_naoshima_linear.png
