Commit 335dc49

added openai unet
1 parent 34a4ef1 commit 335dc49

17 files changed: +1684 -35 lines
+119
@@ -0,0 +1,119 @@
import torch
from torch import nn, Tensor
from jaxtyping import Float, Int64
from typing import Tuple, Union
from diffusion_models.models.positional_encoding import PositionalEncoding
from diffusion_models.models.unet import UNet
from diffusion_models.models.openai_unet import UNetModel
from diffusion_models.models.diffusion import ForwardDiffusion

class DiffusionModelOpenAI(nn.Module):
    """DiffusionModel class that implements a DDPM (denoising diffusion probabilistic model)."""
    def __init__(
        self,
        backbone: UNet,
        fwd_diff: ForwardDiffusion,
        img_size: int,
        time_enc_dim: int=256,
        dropout: float=0,
    ) -> None:
        """Constructor of the DiffusionModelOpenAI class.

        Parameters
        ----------
        backbone
            backbone module (instance) for noise estimation
        fwd_diff
            forward diffusion module (instance)
        img_size
            side length of the (square) images
        time_enc_dim
            feature dimension used for the time embedding/encoding
        dropout
            dropout probability of the dropout layers
        """
        super().__init__()
        self.model = backbone
        self.fwd_diff = fwd_diff
        self.img_size = img_size
        self.time_enc_dim = time_enc_dim
        self.dropout = dropout

        self.time_encoder = PositionalEncoding(d_model=time_enc_dim, dropout=dropout)

    def forward(
        self,
        x: Float[Tensor, "batch channels height width"]
    ) -> Tuple[Float[Tensor, "batch channels height width"], Float[Tensor, "batch channels height width"]]:
        """Predict the noise for a single denoising step.

        Parameters
        ----------
        x
            batch of original images

        Returns
        -------
        out
            tuple of noise predictions and true noise for random timesteps of the diffusion process
        """
        timesteps = self._sample_timesteps(x.shape[0], device=x.device)
        if timesteps.dim() != 1:
            raise ValueError("Timesteps should only have a batch dimension.", timesteps.shape)
        x_t, noise = self.fwd_diff(x, timesteps)
        # predict the applied noise from the noisy version
        noise_pred = self.model(x_t, timesteps / self.fwd_diff.timesteps)
        return noise_pred, noise

    def init_noise(self, num_samples: int) -> Float[Tensor, "batch channels height width"]:
        return torch.randn((num_samples, self.model.in_channels, self.img_size, self.img_size), device=next(self.parameters()).device)

    def denoise_singlestep(
        self,
        x: Float[Tensor, "batch channels height width"],
        t: Int64[Tensor, "batch"]
    ) -> Float[Tensor, "batch channels height width"]:
        """Denoise a single timestep in the reverse direction.

        Parameters
        ----------
        x
            batch of noisy images (the timesteps may differ within the batch)
        t
            current timestep of each batch element

        Returns
        -------
        out
            less noisy batch (one reverse step, i.e. now at t-1)
        """
        self.model.eval()
        with torch.no_grad():
            # condition on normalized timesteps, consistent with forward()
            noise_pred = self.model(x, t / self.fwd_diff.timesteps)
            alpha = self.fwd_diff.alphas[t][:, None, None, None]
            alpha_hat = self.fwd_diff.alphas_dash[t][:, None, None, None]
            beta = self.fwd_diff.betas[t][:, None, None, None]
            noise = torch.randn_like(x, device=noise_pred.device)
            # the injected noise must be zero at t == 1, the final denoising step
            (t_one_idx,) = torch.where(t == 1)
            noise[t_one_idx] = 0
            # DDPM reverse update: posterior mean of x_{t-1} plus sqrt(beta)-scaled noise
            x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * noise_pred) + torch.sqrt(beta) * noise
        self.model.train()
        return x

    def sample(
        self,
        num_samples: int
    ) -> Union[Float[Tensor, "batch channel height width"], Tuple]:
        beta = self.fwd_diff.betas[-1].view(-1, 1, 1, 1)
        x = self.init_noise(num_samples) * torch.sqrt(beta)
        for i in reversed(range(1, self.fwd_diff.timesteps)):
            t = i * torch.ones((num_samples,), dtype=torch.long, device=next(self.model.parameters()).device)
            x = self.denoise_singlestep(x, t)
        return x

    def _sample_timesteps(self, batch_size: int, device: torch.device) -> Int64[Tensor, "batch"]:
        return torch.randint(low=1, high=self.fwd_diff.timesteps, size=(batch_size,), device=device)
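
For orientation, a minimal training/sampling sketch around DiffusionModelOpenAI; the UNetModel and ForwardDiffusion constructor arguments are elided, and `dataloader` is a placeholder, not part of this commit:

# Hypothetical usage sketch; constructor arguments and `dataloader`
# are placeholders, not taken from this commit.
backbone = UNetModel(...)            # OpenAI UNet used as noise estimator
fwd_diff = ForwardDiffusion(...)     # provides betas/alphas/alphas_dash and .timesteps
ddpm = DiffusionModelOpenAI(backbone, fwd_diff, img_size=32)

optim = torch.optim.Adam(ddpm.parameters(), lr=2e-4)
for images in dataloader:                # images: (batch, channels, 32, 32)
    noise_pred, noise = ddpm(images)     # random timesteps are sampled internally
    loss = torch.nn.functional.mse_loss(noise_pred, noise)
    optim.zero_grad()
    loss.backward()
    optim.step()

samples = ddpm.sample(num_samples=16)    # reverse diffusion from pure noise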

diffusion_models/models/fp16_util.py

+76
@@ -0,0 +1,76 @@
"""
Helpers to train with 16-bit precision.
"""

import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def convert_module_to_f16(l):
    """
    Convert primitive modules to float16.
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.half()
        l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
    """
    Convert primitive modules to float32, undoing convert_module_to_f16().
    """
    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        l.weight.data = l.weight.data.float()
        l.bias.data = l.bias.data.float()


def make_master_params(model_params):
    """
    Copy model parameters into a (differently-shaped) list of full-precision
    parameters.
    """
    master_params = _flatten_dense_tensors(
        [param.detach().float() for param in model_params]
    )
    master_params = nn.Parameter(master_params)
    master_params.requires_grad = True
    return [master_params]


def model_grads_to_master_grads(model_params, master_params):
    """
    Copy the gradients from the model parameters into the master parameters
    from make_master_params().
    """
    master_params[0].grad = _flatten_dense_tensors(
        [param.grad.data.detach().float() for param in model_params]
    )


def master_params_to_model_params(model_params, master_params):
    """
    Copy the master parameter data back into the model parameters.
    """
    # Without copying to a list, if a generator is passed, this will
    # silently not copy any parameters.
    model_params = list(model_params)

    for param, master_param in zip(
        model_params, unflatten_master_params(model_params, master_params)
    ):
        param.detach().copy_(master_param)


def unflatten_master_params(model_params, master_params):
    """
    Unflatten the master parameters to look like model_params.
    """
    return _unflatten_dense_tensors(master_params[0].detach(), tuple(tensor for tensor in model_params))


def zero_grad(model_params):
    for param in model_params:
        # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
        if param.grad is not None:
            param.grad.detach_()
            param.grad.zero_()
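
Taken together, these helpers implement the fp16-weights-with-fp32-master-copy pattern; a hedged sketch of one training step (`model`, `compute_loss`, and the loss-scale value are assumptions, not part of this file):

import torch

# Hypothetical wiring of the helpers above; `model` and `compute_loss`
# are placeholders, not part of this commit.
model.apply(convert_module_to_f16)                 # conv weights/biases -> fp16
model_params = list(model.parameters())
master_params = make_master_params(model_params)   # flat fp32 master copy
opt = torch.optim.AdamW(master_params, lr=1e-4)

def train_step(batch, loss_scale=2.0 ** 16):
    zero_grad(model_params)
    loss = compute_loss(model, batch)              # loss computed in fp16
    (loss * loss_scale).backward()                 # scale up to avoid fp16 underflow
    model_grads_to_master_grads(model_params, master_params)
    master_params[0].grad.mul_(1.0 / loss_scale)   # undo the scaling in fp32
    opt.step()                                     # update the fp32 master weights
    master_params_to_model_params(model_params, master_params)  # copy back to fp16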

diffusion_models/models/nn.py

+170
@@ -0,0 +1,170 @@
"""
Various utilities for neural networks.
"""

import math

import torch as th
import torch.nn as nn


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
    def forward(self, x):
        return x * th.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.
    """
    if dims == 1:
        return nn.Conv1d(*args, **kwargs)
    elif dims == 2:
        return nn.Conv2d(*args, **kwargs)
    elif dims == 3:
        return nn.Conv3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
    """
    Create a linear module.
    """
    return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.
    """
    if dims == 1:
        return nn.AvgPool1d(*args, **kwargs)
    elif dims == 2:
        return nn.AvgPool2d(*args, **kwargs)
    elif dims == 3:
        return nn.AvgPool3d(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
    """
    Update target parameters to be closer to those of source parameters using
    an exponential moving average.

    :param target_params: the target parameter sequence.
    :param source_params: the source parameter sequence.
    :param rate: the EMA rate (closer to 1 means slower).
    """
    for targ, src in zip(target_params, source_params):
        targ.detach().mul_(rate).add_(src, alpha=1 - rate)
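
update_ema is typically called once per optimizer step to maintain a smoothed copy of the weights; a minimal sketch, assuming a stand-in model:

import copy

model = nn.Linear(4, 4)               # stand-in for the real network
ema_model = copy.deepcopy(model)
ema_model.requires_grad_(False)       # EMA weights are never trained directly

# after every optimizer step on `model`:
update_ema(ema_model.parameters(), model.parameters(), rate=0.999)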
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def scale_module(module, scale):
    """
    Scale the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().mul_(scale)
    return module


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    return GroupNorm32(32, channels)


def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
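
Here freqs[i] = max_period**(-i/half), so the embedding concatenates cos(t * freqs) and sin(t * freqs) at geometrically spaced frequencies; a quick, illustrative shape check:

t = th.tensor([0, 10, 999])            # one (possibly fractional) timestep per batch element
emb = timestep_embedding(t, dim=128)   # 64 cosine features followed by 64 sine features
assert emb.shape == (3, 128)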
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)


class CheckpointFunction(th.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with th.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
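
The flag argument lets call sites toggle checkpointing per module; a hypothetical block showing the calling convention checkpoint() expects (ResBlockSketch is an illustration, not part of this commit):

class ResBlockSketch(nn.Module):
    # Illustrative module, not part of this commit: wraps its inner forward
    # in checkpoint() so activations are recomputed during the backward pass.
    def __init__(self, channels, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.layers = nn.Sequential(
            normalization(channels),
            SiLU(),
            conv_nd(2, channels, channels, 3, padding=1),
        )

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        return self.layers(x)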
