
Commit ab1dca8: "added cosine scheduler"
1 parent f3cc9aa
11 files changed: +713, -95 lines

diffusion_models/models/diffusion.py

Lines changed: 14 additions & 21 deletions

@@ -19,25 +19,31 @@ class ForwardDiffusion(nn.Module):
     type
         type of scheduler, currently linear and cosine supported
     """
-    def __init__(self, timesteps: int, start: float=0.0001, end: float=0.02, type: Literal["linear", "cosine"]="linear") -> None:
+    def __init__(self, timesteps: int, start: float=0.0001, end: float=0.02, offset: float=0.008, max_beta: float=0.999, type: Literal["linear", "cosine"]="linear") -> None:
         """Constructor of ForwardDiffusion class.
 
         Parameters
         ----------
         timesteps
-            timesteps
+            total number of timesteps in diffusion process
         start
-            start
+            start beta for linear scheduler
         end
-            end
+            end beta for linear scheduler
+        offset
+            offset parameter for cosine scheduler
+        max_beta
+            maximal value to clip betas for cosine scheduler
         type
-            type
+            type of scheduler, either linear or cosine
         """
         super().__init__()
         self.timesteps = timesteps
         self.start = start
         self.end = end
+        self.offset = offset
         self.type = type
+        self.max_beta = max_beta
         if self.type == "linear":
             self.init_betas = self._linear_scheduler(timesteps=self.timesteps, start=self.start, end=self.end)
         elif self.type == "cosine":
@@ -84,26 +90,13 @@ def _linear_scheduler(self, timesteps, start, end):
         return torch.linspace(start, end, timesteps)
 
     def _cosine_scheduler(self, timesteps, start, end):
-        return self._betas_for_alpha_bar(
-            timesteps,
-            lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
-        )
+        """t is actually t/T from the paper"""
+        return self._betas_for_alpha_bar(timesteps, lambda t: math.cos((t + self.offset) / (1.0 + self.offset) * math.pi / 2) ** 2, self.max_beta)
 
     def _betas_for_alpha_bar(self, num_diffusion_timesteps, alpha_bar, max_beta=0.999):
-        """
-        Create a beta schedule that discretizes the given alpha_t_bar function,
-        which defines the cumulative product of (1-beta) over time from t = [0,1].
-
-        :param num_diffusion_timesteps: the number of betas to produce.
-        :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
-            produces the cumulative product of (1-beta) up to that
-            part of the diffusion process.
-        :param max_beta: the maximum beta to use; use values lower than 1 to
-            prevent singularities.
-        """
         betas = []
         for i in range(num_diffusion_timesteps):
-            t1 = i / num_diffusion_timesteps
+            t1 = i / num_diffusion_timesteps  # t -> t/T
             t2 = (i + 1) / num_diffusion_timesteps
             betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
         return torch.tensor(betas)
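
For context, the refactored schedule follows the cosine formulation popularized by Nichol & Dhariwal's improved DDPM work: alpha_bar(t/T) = cos^2(((t/T + offset) / (1 + offset)) * pi / 2), with per-step betas beta_t = 1 - alpha_bar(t2) / alpha_bar(t1), clipped at max_beta because alpha_bar reaches zero at t = T. A minimal standalone sketch of the same computation (the function name cosine_betas is hypothetical; the class methods above are the actual implementation):

    import math
    import torch

    def cosine_betas(timesteps: int, offset: float = 0.008, max_beta: float = 0.999) -> torch.Tensor:
        # alpha_bar(t) with t = t/T in [0, 1], as in _cosine_scheduler above
        alpha_bar = lambda t: math.cos((t + offset) / (1.0 + offset) * math.pi / 2) ** 2
        betas = []
        for i in range(timesteps):
            t1, t2 = i / timesteps, (i + 1) / timesteps
            # clip at max_beta: alpha_bar(1.0) == 0, so the final ratio would give beta = 1
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return torch.tensor(betas)

    print(cosine_betas(1000)[:3])   # tiny betas near t = 0
    print(cosine_betas(1000)[-1])   # clipped to 0.999 at the final step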

diffusion_models/utils/datasets.py

Lines changed: 1 addition & 1 deletion

@@ -22,7 +22,7 @@ def __init__(self, root: str, train: bool = True, transform: Callable[..., Any]
 
 class MNISTTestDataset(MNIST):
     def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None:
-        transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
+        transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,)), Resize((32,32), antialias=True)])
         download = True
         super().__init__(root, train, transform, target_transform, download)
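
The added Resize((32, 32), antialias=True) presumably brings MNIST's native 28x28 digits up to a power-of-two side length so repeated halving in a UNet-style encoder divides evenly. A quick shape check (the blank PIL image is a stand-in for one MNIST digit):

    from PIL import Image
    from torchvision.transforms import Compose, Normalize, Resize, ToTensor

    transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,)), Resize((32, 32), antialias=True)])
    dummy = Image.new("L", (28, 28))    # blank 28x28 grayscale image
    print(transform(dummy).shape)       # torch.Size([1, 32, 32])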

diffusion_models/utils/trainer.py

Lines changed: 6 additions & 0 deletions

@@ -117,6 +117,12 @@ def train(self, max_epochs: int):
         if (self.gpu_id == 0) and (epoch % self.save_every == 0) and (epoch != 0):
             self._save_checkpoint(epoch)
 
+    def load_checkpoint(self, checkpoint_path: str):
+        if self.device_type == "cuda":
+            self.model.module.load_state_dict(torch.load(checkpoint_path))
+        else:
+            self.model.load_state_dict(torch.load(checkpoint_path))
+
 class DiscriminativeTrainer(Trainer):
     def __init__(self, model: Module, train_data: Dataset, loss_func: Callable[..., Any], optimizer: Optimizer, gpu_id: int, batch_size: int, save_every: int, checkpoint_folder: str, device_type: Literal['cuda', 'mps', 'cpu'], log_wandb: bool = True) -> None:
        super().__init__(model, train_data, loss_func, optimizer, gpu_id, batch_size, save_every, checkpoint_folder, device_type, log_wandb)
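
A hypothetical usage sketch for the new load_checkpoint (the constructor arguments and checkpoint path are illustrative, not taken from the repo); the CUDA branch goes through self.model.module because the model is presumably wrapped in DistributedDataParallel on GPU runs:

    trainer = DiscriminativeTrainer(
        model, train_data, loss_func, optimizer,
        gpu_id=0, batch_size=64, save_every=10,
        checkpoint_folder="./checkpoints",
        device_type="cpu", log_wandb=False,
    )
    trainer.load_checkpoint("./checkpoints/epoch_10.pt")  # illustrative path
    trainer.train(max_epochs=100)

One caveat: torch.load with no map_location restores tensors to the device they were saved from, so loading a CUDA-saved checkpoint on a CPU-only machine would additionally need map_location="cpu".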

examples/forward_process.py

Lines changed: 8 additions & 7 deletions

@@ -11,17 +11,18 @@
 import torch.nn as nn
 import context
 from models.diffusion import ForwardDiffusion
+from torchvision.transforms import Compose, Normalize, Resize
 import torchvision
 
 mode = "cosine"
-timesteps = 1000
-every = 100
+timesteps = 1200
+every = 200
 
 #############################################################################
 
 img = "/Users/lionelpeer/Pictures/2020/Japan/darktable_exported/DSC_1808.jpg"
 img = read_image(img) / 255
-transform = torchvision.transforms.Resize((120, 180))
+transform = Compose([Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), Resize((120, 180))])
 img = transform(img)
 
 print("image shape:", img.shape)
@@ -37,17 +38,17 @@
 noiser = ForwardDiffusion(timesteps=timesteps, type=mode).to(device)
 batch = batch.to(device)
 
-noisies, noise = noiser.forward(batch[0], torch.tensor([i*every-1 for i in range(1,timesteps//every)]))
+noisies, noise = noiser.forward(batch[0], torch.tensor([i*every for i in range(0,timesteps//every)]))
 noisies = [noisies[i].permute(1,2,0) for i in range(noisies.shape[0])]
 
 blub = [r"$x_{} \sim q(x_{})$".format("{"+str(0)+"}", "{"+str(0)+"}")]
 titles = [
-    r"$x_{} \sim q(x_{}\mid x_{})$".format("{"+str(i)+"}", "{"+str(i)+"}", "{"+str(i-1)+"}") for i in [j*every for j in range(1, timesteps//every)]
+    r"$x_{} \sim q(x_{}\mid x_{})$".format("{"+str(i)+"}", "{"+str(i)+"}", "{"+str(i-1)+"}") for i in [j*every for j in range(1, timesteps//every+1)]
 ]
 blub.extend(titles)
 
-fig, ax = plt.subplots(1,timesteps//every-1,figsize=(25,5))
-for i, (elem, title) in enumerate(zip(noisies[:-1], blub[:-1])):
+fig, ax = plt.subplots(1,timesteps//every,figsize=(25,5))
+for i, (elem, title) in enumerate(zip(noisies[:], blub[:])):
     ax[i].imshow(elem.cpu())
     ax[i].axis("off")
     ax[i].set_title(title)
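
The reworked indexing starts the visualization at t = 0 (the clean image) and samples at exact multiples of every, so the subplot grid now matches the number of noised images plotted. A quick check of the index arithmetic:

    timesteps, every = 1200, 200
    print([i * every for i in range(0, timesteps // every)])  # [0, 200, 400, 600, 800, 1000]
    print(timesteps // every)                                 # 6 subplot columns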

2 binary image files changed (157 KB and 156 KB); images not shown.
