
Commit

added openai unet
liopeer committed May 8, 2024
1 parent 34a4ef1 commit 335dc49
Showing 17 changed files with 1,684 additions and 35 deletions.
119 changes: 119 additions & 0 deletions diffusion_models/models/diffusion_openai.py
@@ -0,0 +1,119 @@
import torch
from torch import nn, Tensor
from jaxtyping import Float, Int64, Int
from typing import Literal, Tuple, Union, List
from diffusion_models.models.positional_encoding import PositionalEncoding
import math
from diffusion_models.models.unet import UNet
from diffusion_models.models.openai_unet import UNetModel
from diffusion_models.models.diffusion import ForwardDiffusion

class DiffusionModelOpenAI(nn.Module):
"""DiffusionModel class that implements a DDPM (denoising diffusion probabilistic model)."""
def __init__(
self,
backbone: UNet,
fwd_diff: ForwardDiffusion,
img_size: int,
time_enc_dim: int=256,
dropout: float=0,
) -> None:
"""Constructor of DiffusionModel class.
Parameters
----------
backbone
backbone module (instance) for noise estimation
fwd_diff
forward diffusion module (instance)
img_size
size of (quadratic) images
time_enc_dim
feature dimension that should be used for time embedding/encoding
dropout
value of dropout layers
"""
super().__init__()
self.model = backbone
self.fwd_diff = fwd_diff
self.img_size = img_size
self.time_enc_dim = time_enc_dim
self.dropout = dropout

self.time_encoder = PositionalEncoding(d_model=time_enc_dim, dropout=dropout)

def forward(
self,
x: Float[Tensor, "batch channels height width"]
) -> Tuple[Float[Tensor, "batch channels height width"], Float[Tensor, "batch channels height width"]]:
"""Predict noise for single denoising step.
Parameters
----------
x
batch of original images
Returns
-------
out
tuple of noise predictions and noise for random timesteps in the denoising process
"""
timesteps = self._sample_timesteps(x.shape[0], device=x.device)
if timesteps.dim() != 1:
raise ValueError("Timesteps should only have batch dimension.", timesteps.shape)
x_t, noise = self.fwd_diff(x, timesteps)
# predict the applied noise from the noisy version
noise_pred = self.model(x_t, timesteps/self.fwd_diff.timesteps)
return noise_pred, noise

    def init_noise(self, num_samples: int) -> Float[Tensor, "batch channels height width"]:
        device = next(self.parameters()).device
        return torch.randn((num_samples, self.model.in_channels, self.img_size, self.img_size), device=device)

def denoise_singlestep(
self,
x: Float[Tensor, "batch channels height width"],
t: Int64[Tensor, "batch"]
) -> Float[Tensor, "batch channels height width"]:
"""Denoise single timestep in reverse direction.
Parameters
----------
x
tensor representing a batch of noisy pictures (may be of different timesteps)
t
tensor representing the t timesteps for the batch (where the batch now is)
Returns
-------
out
less noisy version (by one timestep, now at t-1 from the arguments)
"""
        self.model.eval()
        with torch.no_grad():
            # condition on normalized timesteps so that sampling matches the conditioning used in forward()
            noise_pred = self.model(x, t / self.fwd_diff.timesteps)
alpha = self.fwd_diff.alphas[t][:, None, None, None]
alpha_hat = self.fwd_diff.alphas_dash[t][:, None, None, None]
beta = self.fwd_diff.betas[t][:, None, None, None]
noise = torch.randn_like(x, device=noise_pred.device)
# noise where t = 1 should be zero
(t_one_idx, ) = torch.where(t==1)
noise[t_one_idx] = 0
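            # DDPM reverse step: x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta) + sqrt(beta_t) * z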
x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * noise_pred) + torch.sqrt(beta) * noise
self.model.train()
return x

    def sample(
        self,
        num_samples: int
    ) -> Float[Tensor, "batch channel height width"]:
        # start from Gaussian noise scaled by the final beta of the schedule
        beta = self.fwd_diff.betas[-1].view(-1, 1, 1, 1)
        x = self.init_noise(num_samples) * torch.sqrt(beta)
        device = next(self.model.parameters()).device
        for i in reversed(range(1, self.fwd_diff.timesteps)):
            t = i * torch.ones((num_samples,), dtype=torch.long, device=device)
            x = self.denoise_singlestep(x, t)
        return x

    def _sample_timesteps(self, batch_size: int, device: torch.device) -> Int64[Tensor, "batch"]:
        return torch.randint(low=1, high=self.fwd_diff.timesteps, size=(batch_size,), device=device)
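
For orientation (not part of the commit): a minimal sketch of how the new class could be driven in a DDPM training step and at sampling time. The constructor arguments shown for UNetModel and ForwardDiffusion are assumptions, not the actual signatures in this repository.

import torch
import torch.nn.functional as F
from diffusion_models.models.openai_unet import UNetModel
from diffusion_models.models.diffusion import ForwardDiffusion
from diffusion_models.models.diffusion_openai import DiffusionModelOpenAI

# hypothetical constructor arguments -- adapt to the real UNetModel / ForwardDiffusion APIs
backbone = UNetModel(in_channels=3, model_channels=64, out_channels=3,
                     num_res_blocks=2, attention_resolutions=(8,))
fwd_diff = ForwardDiffusion(timesteps=1000)
model = DiffusionModelOpenAI(backbone, fwd_diff, img_size=64)

x = torch.randn(8, 3, 64, 64)            # stand-in for a batch of training images
noise_pred, noise = model(x)             # forward() draws random timesteps internally
loss = F.mse_loss(noise_pred, noise)     # standard DDPM noise-prediction objective
loss.backward()

samples = model.sample(num_samples=4)    # reverse diffusion starting from noise
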
76 changes: 76 additions & 0 deletions diffusion_models/models/fp16_util.py
@@ -0,0 +1,76 @@
"""
Helpers to train with 16-bit precision.
"""

import torch.nn as nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def convert_module_to_f16(l):
"""
Convert primitive modules to float16.
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.half()
l.bias.data = l.bias.data.half()


def convert_module_to_f32(l):
"""
Convert primitive modules to float32, undoing convert_module_to_f16().
"""
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
l.weight.data = l.weight.data.float()
l.bias.data = l.bias.data.float()


def make_master_params(model_params):
"""
Copy model parameters into a (differently-shaped) list of full-precision
parameters.
"""
master_params = _flatten_dense_tensors(
[param.detach().float() for param in model_params]
)
master_params = nn.Parameter(master_params)
master_params.requires_grad = True
return [master_params]


def model_grads_to_master_grads(model_params, master_params):
"""
Copy the gradients from the model parameters into the master parameters
from make_master_params().
"""
master_params[0].grad = _flatten_dense_tensors(
[param.grad.data.detach().float() for param in model_params]
)


def master_params_to_model_params(model_params, master_params):
"""
Copy the master parameter data back into the model parameters.
"""
# Without copying to a list, if a generator is passed, this will
# silently not copy any parameters.
model_params = list(model_params)

for param, master_param in zip(
model_params, unflatten_master_params(model_params, master_params)
):
param.detach().copy_(master_param)


def unflatten_master_params(model_params, master_params):
"""
Unflatten the master parameters to look like model_params.
"""
return _unflatten_dense_tensors(master_params[0].detach(), tuple(tensor for tensor in model_params))


def zero_grad(model_params):
for param in model_params:
# Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
if param.grad is not None:
param.grad.detach_()
param.grad.zero_()
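
A sketch of the intended fp16 flow with these helpers (not part of the commit; assumes a CUDA device and a toy conv layer): the model keeps half-precision weights, while a flat fp32 master copy is what the optimizer updates.

import torch
import torch.nn as nn
from diffusion_models.models.fp16_util import (
    convert_module_to_f16, make_master_params, model_grads_to_master_grads,
    master_params_to_model_params, zero_grad)

model = nn.Conv2d(3, 8, 3).cuda()
model.apply(convert_module_to_f16)                  # cast conv weights/biases to fp16
model_params = list(model.parameters())
master_params = make_master_params(model_params)    # flat fp32 copy of the weights
opt = torch.optim.Adam(master_params, lr=1e-4)

x = torch.randn(2, 3, 16, 16, device="cuda", dtype=torch.float16)
zero_grad(model_params)
model(x).float().sum().backward()                   # fp16 forward/backward
model_grads_to_master_grads(model_params, master_params)    # fp16 grads -> fp32 master grads
opt.step()                                          # update the fp32 master weights
master_params_to_model_params(model_params, master_params)  # copy them back into the fp16 model
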
170 changes: 170 additions & 0 deletions diffusion_models/models/nn.py
@@ -0,0 +1,170 @@
"""
Various utilities for neural networks.
"""

import math

import torch as th
import torch.nn as nn


# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * th.sigmoid(x)


class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)


def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")


def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)


def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")


def update_ema(target_params, source_params, rate=0.99):
"""
Update target parameters to be closer to those of source parameters using
an exponential moving average.
:param target_params: the target parameter sequence.
:param source_params: the source parameter sequence.
:param rate: the EMA rate (closer to 1 means slower).
"""
for targ, src in zip(target_params, source_params):
targ.detach().mul_(rate).add_(src, alpha=1 - rate)
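
For illustration (assuming a toy model, not part of the commit), the usual pattern is to keep a deep copy of the model and pull it toward the live weights after every optimizer step:

import copy
import torch.nn as nn
from diffusion_models.models.nn import update_ema

net = nn.Linear(8, 8)
ema_net = copy.deepcopy(net)        # EMA copy used for evaluation/sampling
# ... after each optimizer step on `net`:
update_ema(ema_net.parameters(), net.parameters(), rate=0.999)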


def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module


def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module


def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))


def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)


def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
half = dim // 2
freqs = th.exp(
-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
if dim % 2:
embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
return embedding
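
A quick usage sketch (not part of the commit): eight timesteps mapped to 128-dimensional sinusoidal embeddings.

import torch as th
from diffusion_models.models.nn import timestep_embedding

t = th.arange(8).float()             # timesteps may be fractional
emb = timestep_embedding(t, dim=128)
print(emb.shape)                     # torch.Size([8, 128])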


def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)


class CheckpointFunction(th.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with th.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors

@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with th.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = th.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
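
A sketch of how checkpoint()/CheckpointFunction can be used on a toy block (example assumed, not taken from the commit): activations of the block are not stored during the forward pass and are recomputed during backward.

import torch as th
import torch.nn as nn
from diffusion_models.models.nn import checkpoint

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
x = th.randn(4, 16, requires_grad=True)

y = checkpoint(block, (x,), block.parameters(), flag=True)
y.sum().backward()      # block's forward is re-run here to rebuild activations
print(x.grad.shape)     # gradients flow to the inputs as usual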
