From 34a4ef14ff1ca2d4e88db3298398968671a4e279 Mon Sep 17 00:00:00 2001 From: Lionel Peer Date: Sun, 14 Apr 2024 10:16:48 +0000 Subject: [PATCH] all warnings gone --- diffusion_models/__init__.py | 1 - diffusion_models/losses/kl_divergence.py | 4 +- diffusion_models/models/diffusion.py | 4 +- diffusion_models/models/multicoil.py | 2 +- .../models/repaint_unet/fp16_util.py | 245 ----- diffusion_models/models/repaint_unet/nn.py | 186 ---- diffusion_models/models/repaint_unet/unet.py | 912 ------------------ diffusion_models/models/sampler.py | 4 +- diffusion_models/models/unet.py | 2 +- diffusion_models/models/vae.py | 1 + diffusion_models/mri_forward/fft.py | 2 +- diffusion_models/utils/datasets.py | 2 +- diffusion_models/utils/trainer.py | 2 +- docs/source/conf.py | 7 +- packages.dot | 25 - requirements.txt | 3 +- 16 files changed, 19 insertions(+), 1383 deletions(-) delete mode 100644 diffusion_models/models/repaint_unet/fp16_util.py delete mode 100644 diffusion_models/models/repaint_unet/nn.py delete mode 100644 diffusion_models/models/repaint_unet/unet.py delete mode 100644 packages.dot diff --git a/diffusion_models/__init__.py b/diffusion_models/__init__.py index 06fbe7e..e69de29 100644 --- a/diffusion_models/__init__.py +++ b/diffusion_models/__init__.py @@ -1 +0,0 @@ -version = "0.0.1" \ No newline at end of file diff --git a/diffusion_models/losses/kl_divergence.py b/diffusion_models/losses/kl_divergence.py index 60ec0d6..fa6f828 100644 --- a/diffusion_models/losses/kl_divergence.py +++ b/diffusion_models/losses/kl_divergence.py @@ -10,7 +10,7 @@ def gaussian_kl( ) -> Float[Tensor, "1"]: """Calculate KL Divergence of 2 Gaussian distributions. - KL divergence between two univariate Gaussians, as derived in [1], with k=1 (dimensionality). + KL divergence between two univariate Gaussians, as derived in [1]_, with k=1 (dimensionality). Parameters ---------- @@ -42,7 +42,7 @@ def log_gaussian_kl( ) -> Float[Tensor, "1"]: """Calculate KL Divergence of 2 Gaussian distributions. - KL divergence between two univariate Gaussians, as derived in [1], with k=1 (dimensionality) and log variances. + KL divergence between two univariate Gaussians, as derived in [1]_, with k=1 (dimensionality) and log variances. Parameters ---------- diff --git a/diffusion_models/models/diffusion.py b/diffusion_models/models/diffusion.py index 18920a4..36f0452 100644 --- a/diffusion_models/models/diffusion.py +++ b/diffusion_models/models/diffusion.py @@ -2,9 +2,9 @@ from torch import nn, Tensor from jaxtyping import Float, Int64, Int from typing import Literal, Tuple, Union, List -from models.positional_encoding import PositionalEncoding +from diffusion_models.models.positional_encoding import PositionalEncoding import math -from models.unet import UNet +from diffusion_models.models.unet import UNet class ForwardDiffusion(nn.Module): """Class for forward diffusion process in DDPMs (denoising diffusion probabilistic models). diff --git a/diffusion_models/models/multicoil.py b/diffusion_models/models/multicoil.py index 31e1b13..beacdcd 100644 --- a/diffusion_models/models/multicoil.py +++ b/diffusion_models/models/multicoil.py @@ -21,7 +21,7 @@ def __init__(self, channel_factors: List[int]=(4, 8, 16, 32), kernel_size: int=3 This class takes every coil independently (treats them like a sub-fraction of a batch), increases the channel size massively (from 2 initial channels for complex k-space data) via several convolutional layers and then averages those channels over the coil dimension. 
Averaging is invariant to permutations of the input order, so the coil order - or the number of coils will not matter anymore. Inspiration was drawn from point cloud processing, see below. + or the number of coils will not matter anymore. Inspiration was drawn from point cloud processing [1]_, see below. .. [1] Qi et al., PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation, 2017 diff --git a/diffusion_models/models/repaint_unet/fp16_util.py b/diffusion_models/models/repaint_unet/fp16_util.py deleted file mode 100644 index 0d885f2..0000000 --- a/diffusion_models/models/repaint_unet/fp16_util.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright (c) 2022 Huawei Technologies Co., Ltd. -# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode -# -# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license - -""" -Helpers to train with 16-bit precision. -""" - -import numpy as np -import torch as th -import torch.nn as nn -from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors - - -INITIAL_LOG_LOSS_SCALE = 20.0 - - -def convert_module_to_f16(l): - """ - Convert primitive modules to float16. - """ - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() - - -def convert_module_to_f32(l): - """ - Convert primitive modules to float32, undoing convert_module_to_f16(). - """ - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): - l.weight.data = l.weight.data.float() - if l.bias is not None: - l.bias.data = l.bias.data.float() - - -def make_master_params(param_groups_and_shapes): - """ - Copy model parameters into a (differently-shaped) list of full-precision - parameters. - """ - master_params = [] - for param_group, shape in param_groups_and_shapes: - master_param = nn.Parameter( - _flatten_dense_tensors( - [param.detach().float() for (_, param) in param_group] - ).view(shape) - ) - master_param.requires_grad = True - master_params.append(master_param) - return master_params - - -def model_grads_to_master_grads(param_groups_and_shapes, master_params): - """ - Copy the gradients from the model parameters into the master parameters - from make_master_params(). - """ - for master_param, (param_group, shape) in zip( - master_params, param_groups_and_shapes - ): - master_param.grad = _flatten_dense_tensors( - [param_grad_or_zeros(param) for (_, param) in param_group] - ).view(shape) - - -def master_params_to_model_params(param_groups_and_shapes, master_params): - """ - Copy the master parameter data back into the model parameters. - """ - # Without copying to a list, if a generator is passed, this will - # silently not copy any parameters. 
- for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): - for (_, param), unflat_master_param in zip( - param_group, unflatten_master_params(param_group, master_param.view(-1)) - ): - param.detach().copy_(unflat_master_param) - - -def unflatten_master_params(param_group, master_param): - return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) - - -def get_param_groups_and_shapes(named_model_params): - named_model_params = list(named_model_params) - scalar_vector_named_params = ( - [(n, p) for (n, p) in named_model_params if p.ndim <= 1], - (-1), - ) - matrix_named_params = ( - [(n, p) for (n, p) in named_model_params if p.ndim > 1], - (1, -1), - ) - return [scalar_vector_named_params, matrix_named_params] - - -def master_params_to_state_dict( - model, param_groups_and_shapes, master_params, use_fp16 -): - if use_fp16: - state_dict = model.state_dict() - for master_param, (param_group, _) in zip( - master_params, param_groups_and_shapes - ): - for (name, _), unflat_master_param in zip( - param_group, unflatten_master_params(param_group, master_param.view(-1)) - ): - assert name in state_dict - state_dict[name] = unflat_master_param - else: - state_dict = model.state_dict() - for i, (name, _value) in enumerate(model.named_parameters()): - assert name in state_dict - state_dict[name] = master_params[i] - return state_dict - - -def state_dict_to_master_params(model, state_dict, use_fp16): - if use_fp16: - named_model_params = [ - (name, state_dict[name]) for name, _ in model.named_parameters() - ] - param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) - master_params = make_master_params(param_groups_and_shapes) - else: - master_params = [state_dict[name] for name, _ in model.named_parameters()] - return master_params - - -def zero_master_grads(master_params): - for param in master_params: - param.grad = None - - -def zero_grad(model_params): - for param in model_params: - # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group - if param.grad is not None: - param.grad.detach_() - param.grad.zero_() - - -def param_grad_or_zeros(param): - if param.grad is not None: - return param.grad.data.detach() - else: - return th.zeros_like(param) - - -class MixedPrecisionTrainer: - def __init__( - self, - *, - model, - use_fp16=False, - fp16_scale_growth=1e-3, - initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, - ): - self.model = model - self.use_fp16 = use_fp16 - self.fp16_scale_growth = fp16_scale_growth - - self.model_params = list(self.model.parameters()) - self.master_params = self.model_params - self.param_groups_and_shapes = None - self.lg_loss_scale = initial_lg_loss_scale - - if self.use_fp16: - self.param_groups_and_shapes = get_param_groups_and_shapes( - self.model.named_parameters() - ) - self.master_params = make_master_params(self.param_groups_and_shapes) - self.model.convert_to_fp16() - - def zero_grad(self): - zero_grad(self.model_params) - - def backward(self, loss: th.Tensor): - if self.use_fp16: - loss_scale = 2 ** self.lg_loss_scale - (loss * loss_scale).backward() - else: - loss.backward() - - def optimize(self, opt: th.optim.Optimizer): - if self.use_fp16: - return self._optimize_fp16(opt) - else: - return self._optimize_normal(opt) - - def _optimize_fp16(self, opt: th.optim.Optimizer): - model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) - grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) - if 
check_overflow(grad_norm): - self.lg_loss_scale -= 1 - zero_master_grads(self.master_params) - return False - - for p in self.master_params: - p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) - opt.step() - zero_master_grads(self.master_params) - master_params_to_model_params(self.param_groups_and_shapes, self.master_params) - self.lg_loss_scale += self.fp16_scale_growth - return True - - def _optimize_normal(self, opt: th.optim.Optimizer): - grad_norm, param_norm = self._compute_norms() - opt.step() - return True - - def _compute_norms(self, grad_scale=1.0): - grad_norm = 0.0 - param_norm = 0.0 - for p in self.master_params: - with th.no_grad(): - param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 - if p.grad is not None: - grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 - return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) - - def master_params_to_state_dict(self, master_params): - return master_params_to_state_dict( - self.model, self.param_groups_and_shapes, master_params, self.use_fp16 - ) - - def state_dict_to_master_params(self, state_dict): - return state_dict_to_master_params(self.model, state_dict, self.use_fp16) - - -def check_overflow(value): - return (value == float("inf")) or (value == -float("inf")) or (value != value) \ No newline at end of file diff --git a/diffusion_models/models/repaint_unet/nn.py b/diffusion_models/models/repaint_unet/nn.py deleted file mode 100644 index 673d810..0000000 --- a/diffusion_models/models/repaint_unet/nn.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) 2022 Huawei Technologies Co., Ltd. -# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode -# -# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license - -""" -Various utilities for neural networks. -""" - -import math - -import torch as th -import torch.nn as nn - - -# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. -class SiLU(nn.Module): - def forward(self, x): - return x * th.sigmoid(x) - - -class GroupNorm32(nn.GroupNorm): - def forward(self, x): - return super().forward(x.float()).type(x.dtype) - - -def conv_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D convolution module. - """ - if dims == 1: - return nn.Conv1d(*args, **kwargs) - elif dims == 2: - return nn.Conv2d(*args, **kwargs) - elif dims == 3: - return nn.Conv3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def linear(*args, **kwargs): - """ - Create a linear module. - """ - return nn.Linear(*args, **kwargs) - - -def avg_pool_nd(dims, *args, **kwargs): - """ - Create a 1D, 2D, or 3D average pooling module. 
- """ - if dims == 1: - return nn.AvgPool1d(*args, **kwargs) - elif dims == 2: - return nn.AvgPool2d(*args, **kwargs) - elif dims == 3: - return nn.AvgPool3d(*args, **kwargs) - raise ValueError(f"unsupported dimensions: {dims}") - - -def update_ema(target_params, source_params, rate=0.99): - """ - Update target parameters to be closer to those of source parameters using - an exponential moving average. - - :param target_params: the target parameter sequence. - :param source_params: the source parameter sequence. - :param rate: the EMA rate (closer to 1 means slower). - """ - for targ, src in zip(target_params, source_params): - targ.detach().mul_(rate).add_(src, alpha=1 - rate) - - -def zero_module(module): - """ - Zero out the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().zero_() - return module - - -def scale_module(module, scale): - """ - Scale the parameters of a module and return it. - """ - for p in module.parameters(): - p.detach().mul_(scale) - return module - - -def mean_flat(tensor): - """ - Take the mean over all non-batch dimensions. - """ - return tensor.mean(dim=list(range(1, len(tensor.shape)))) - - -def normalization(channels): - """ - Make a standard normalization layer. - - :param channels: number of input channels. - :return: an nn.Module for normalization. - """ - return GroupNorm32(32, channels) - - -def timestep_embedding(timesteps, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - half = dim // 2 - freqs = th.exp( - -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half - ).to(device=timesteps.device) - args = timesteps[:, None].float() * freqs[None] - embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) - if dim % 2: - embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - -def checkpoint(func, inputs, params, flag): - """ - Evaluate a function without caching intermediate activations, allowing for - reduced memory at the expense of extra compute in the backward pass. - - :param func: the function to evaluate. - :param inputs: the argument sequence to pass to `func`. - :param params: a sequence of parameters `func` depends on but does not - explicitly take as arguments. - :param flag: if False, disable gradient checkpointing. - """ - if flag: - args = tuple(inputs) + tuple(params) - return CheckpointFunction.apply(func, len(inputs), *args) - else: - return func(*inputs) - - -class CheckpointFunction(th.autograd.Function): - @staticmethod - def forward(ctx, run_function, length, *args): - ctx.run_function = run_function - ctx.input_tensors = list(args[:length]) - ctx.input_params = list(args[length:]) - with th.no_grad(): - output_tensors = ctx.run_function(*ctx.input_tensors) - return output_tensors - - @staticmethod - def backward(ctx, *output_grads): - ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] - with th.enable_grad(): - # Fixes a bug where the first op in run_function modifies the - # Tensor storage in place, which is not allowed for detach()'d - # Tensors. 
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors] - output_tensors = ctx.run_function(*shallow_copies) - input_grads = th.autograd.grad( - output_tensors, - ctx.input_tensors + ctx.input_params, - output_grads, - allow_unused=True, - ) - del ctx.input_tensors - del ctx.input_params - del output_tensors - return (None, None) + input_grads diff --git a/diffusion_models/models/repaint_unet/unet.py b/diffusion_models/models/repaint_unet/unet.py deleted file mode 100644 index 4fa65b9..0000000 --- a/diffusion_models/models/repaint_unet/unet.py +++ /dev/null @@ -1,912 +0,0 @@ -# Copyright (c) 2022 Huawei Technologies Co., Ltd. -# Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode -# -# The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# This repository was forked from https://github.com/openai/guided-diffusion, which is under the MIT license - -from abc import abstractmethod - -import math - -from .fp16_util import convert_module_to_f16, convert_module_to_f32 -import torch as th -import torch.nn as nn -import torch.nn.functional as F - -from .nn import ( - checkpoint, - conv_nd, - linear, - avg_pool_nd, - zero_module, - normalization, - timestep_embedding, -) - - -class AttentionPool2d(nn.Module): - """ - Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py - """ - - def __init__( - self, - spacial_dim: int, - embed_dim: int, - num_heads_channels: int, - output_dim: int = None, - ): - super().__init__() - self.positional_embedding = nn.Parameter( - th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5 - ) - self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) - self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) - self.num_heads = embed_dim // num_heads_channels - self.attention = QKVAttention(self.num_heads) - - def forward(self, x, **kwargs): - b, c, *_spatial = x.shape - x = x.reshape(b, c, -1) # NC(HW) - x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) - x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) - x = self.qkv_proj(x) - x = self.attention(x) - x = self.c_proj(x) - return x[:, :, 0] - - -class TimestepBlock(nn.Module): - """ - Any module where forward() takes timestep embeddings as a second argument. - """ - - @abstractmethod - def forward(self, x, emb): - """ - Apply the module to `x` given `emb` timestep embeddings. - """ - - -class TimestepEmbedSequential(nn.Sequential, TimestepBlock): - """ - A sequential module that passes timestep embeddings to the children that - support it as an extra input. - """ - - def forward(self, x, emb): - for layer in self: - if isinstance(layer, TimestepBlock): - x = layer(x, emb) - else: - x = layer(x) - return x - - -class Upsample(nn.Module): - """ - An upsampling layer with an optional convolution. - - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. 
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - upsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - if use_conv: - self.conv = conv_nd(dims, self.channels, - self.out_channels, 3, padding=1) - - def forward(self, x): - assert x.shape[1] == self.channels - if self.dims == 3: - x = F.interpolate( - x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest" - ) - else: - x = F.interpolate(x, scale_factor=2, mode="nearest") - if self.use_conv: - x = self.conv(x) - return x - - -class Downsample(nn.Module): - """ - A downsampling layer with an optional convolution. - - :param channels: channels in the inputs and outputs. - :param use_conv: a bool determining if a convolution is applied. - :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then - downsampling occurs in the inner-two dimensions. - """ - - def __init__(self, channels, use_conv, dims=2, out_channels=None): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.dims = dims - stride = 2 if dims != 3 else (1, 2, 2) - if use_conv: - self.op = conv_nd( - dims, self.channels, self.out_channels, 3, stride=stride, padding=1 - ) - else: - assert self.channels == self.out_channels - self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) - - def forward(self, x): - assert x.shape[1] == self.channels - return self.op(x) - - -class ResBlock(TimestepBlock): - """ - A residual block that can optionally change the number of channels. - - :param channels: the number of input channels. - :param emb_channels: the number of timestep embedding channels. - :param dropout: the rate of dropout. - :param out_channels: if specified, the number of out channels. - :param use_conv: if True and out_channels is specified, use a spatial - convolution instead of a smaller 1x1 convolution to change the - channels in the skip connection. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param use_checkpoint: if True, use gradient checkpointing on this module. - :param up: if True, use this block for upsampling. - :param down: if True, use this block for downsampling. 
- """ - - def __init__( - self, - channels, - emb_channels, - dropout, - out_channels=None, - use_conv=False, - use_scale_shift_norm=False, - dims=2, - use_checkpoint=False, - up=False, - down=False, - ): - super().__init__() - self.channels = channels - self.emb_channels = emb_channels - self.dropout = dropout - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_checkpoint = use_checkpoint - self.use_scale_shift_norm = use_scale_shift_norm - - self.in_layers = nn.Sequential( - normalization(channels), - nn.SiLU(), - conv_nd(dims, channels, self.out_channels, 3, padding=1), - ) - - self.updown = up or down - - if up: - self.h_upd = Upsample(channels, False, dims) - self.x_upd = Upsample(channels, False, dims) - elif down: - self.h_upd = Downsample(channels, False, dims) - self.x_upd = Downsample(channels, False, dims) - else: - self.h_upd = self.x_upd = nn.Identity() - - self.emb_layers = nn.Sequential( - nn.SiLU(), - linear( - emb_channels, - 2 * self.out_channels if use_scale_shift_norm else self.out_channels, - ), - ) - self.out_layers = nn.Sequential( - normalization(self.out_channels), - nn.SiLU(), - nn.Dropout(p=dropout), - zero_module( - conv_nd(dims, self.out_channels, - self.out_channels, 3, padding=1) - ), - ) - - if self.out_channels == channels: - self.skip_connection = nn.Identity() - elif use_conv: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 3, padding=1 - ) - else: - self.skip_connection = conv_nd( - dims, channels, self.out_channels, 1) - - def forward(self, x, emb): - """ - Apply the block to a Tensor, conditioned on a timestep embedding. - - :param x: an [N x C x ...] Tensor of features. - :param emb: an [N x emb_channels] Tensor of timestep embeddings. - :return: an [N x C x ...] Tensor of outputs. - """ - return checkpoint( - self._forward, (x, emb), self.parameters(), self.use_checkpoint - ) - - def _forward(self, x, emb): - if self.updown: - in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] - h = in_rest(x) - h = self.h_upd(h) - x = self.x_upd(x) - h = in_conv(h) - else: - h = self.in_layers(x) - emb_out = self.emb_layers(emb).type(h.dtype) - while len(emb_out.shape) < len(h.shape): - emb_out = emb_out[..., None] - if self.use_scale_shift_norm: - out_norm, out_rest = self.out_layers[0], self.out_layers[1:] - scale, shift = th.chunk(emb_out, 2, dim=1) - h = out_norm(h) * (1 + scale) + shift - h = out_rest(h) - else: - h = h + emb_out - h = self.out_layers(h) - return self.skip_connection(x) + h - - -class AttentionBlock(nn.Module): - """ - An attention block that allows spatial positions to attend to each other. - - Originally ported from here, but adapted to the N-d case. - https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
- """ - - def __init__( - self, - channels, - num_heads=1, - num_head_channels=-1, - use_checkpoint=False, - use_new_attention_order=False, - ): - super().__init__() - self.channels = channels - if num_head_channels == -1: - self.num_heads = num_heads - else: - assert ( - channels % num_head_channels == 0 - ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" - self.num_heads = channels // num_head_channels - self.use_checkpoint = use_checkpoint - self.norm = normalization(channels) - self.qkv = conv_nd(1, channels, channels * 3, 1) - if use_new_attention_order: - # split qkv before split heads - self.attention = QKVAttention(self.num_heads) - else: - # split heads before split qkv - self.attention = QKVAttentionLegacy(self.num_heads) - - self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) - - def forward(self, x): - return checkpoint(self._forward, (x,), self.parameters(), True) - - def _forward(self, x): - b, c, *spatial = x.shape - - # Both spacial dimensions to a single verctor - x = x.reshape(b, c, -1) - - # Predict core key values using a 1x1 convolusion (h*w -> 3*h*2) - qkv = self.qkv(self.norm(x)) - - h = self.attention(qkv) - - h = self.proj_out(h) - return (x + h).reshape(b, c, *spatial) - - -class QKVAttentionLegacy(nn.Module): - """ - A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - - :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, - length).split(ch, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", q * scale, k * scale - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, v) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class QKVAttention(nn.Module): - """ - A module which performs QKV attention and splits in a different order. - """ - - def __init__(self, n_heads): - super().__init__() - self.n_heads = n_heads - - def forward(self, qkv): - """ - Apply QKV attention. - - :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. - :return: an [N x (H * C) x T] tensor after attention. - """ - bs, width, length = qkv.shape - assert width % (3 * self.n_heads) == 0 - ch = width // (3 * self.n_heads) - q, k, v = qkv.chunk(3, dim=1) - scale = 1 / math.sqrt(math.sqrt(ch)) - weight = th.einsum( - "bct,bcs->bts", - (q * scale).view(bs * self.n_heads, ch, length), - (k * scale).view(bs * self.n_heads, ch, length), - ) # More stable with f16 than dividing afterwards - weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) - a = th.einsum("bts,bcs->bct", weight, - v.reshape(bs * self.n_heads, ch, length)) - return a.reshape(bs, -1, length) - - @staticmethod - def count_flops(model, _x, y): - return count_flops_attn(model, _x, y) - - -class UNetModel(nn.Module): - """ - The full UNet model with attention and timestep embedding. - - :param in_channels: channels in the input Tensor. - :param model_channels: base channel count for the model. - :param out_channels: channels in the output Tensor. 
- :param num_res_blocks: number of residual blocks per downsample. - :param attention_resolutions: a collection of downsample rates at which - attention will take place. May be a set, list, or tuple. - For example, if this contains 4, then at 4x downsampling, attention - will be used. - :param dropout: the dropout probability. - :param channel_mult: channel multiplier for each level of the UNet. - :param conv_resample: if True, use learned convolutions for upsampling and - downsampling. - :param dims: determines if the signal is 1D, 2D, or 3D. - :param num_classes: if specified (as an int), then this model will be - class-conditional with `num_classes` classes. - :param use_checkpoint: use gradient checkpointing to reduce memory usage. - :param num_heads: the number of attention heads in each attention layer. - :param num_heads_channels: if specified, ignore num_heads and instead use - a fixed channel width per attention head. - :param num_heads_upsample: works with num_heads to set a different number - of heads for upsampling. Deprecated. - :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. - :param resblock_updown: use residual blocks for up/downsampling. - :param use_new_attention_order: use a different attention pattern for potentially - increased efficiency. - """ - - def __init__( - self, - image_size, - in_channels, - num_encoding_blocks, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - time_embed_dim, - dropout=0, - conv_resample=True, - dims=2, - num_classes=None, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - conf=None - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - self.image_size = image_size - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.time_embed_dim = time_embed_dim - self.dropout = dropout - self.channel_mult = tuple(2**i for i in range(num_encoding_blocks)) - self.conv_resample = conv_resample - self.num_classes = num_classes - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - self.conf = conf - - self.time_embed = nn.Sequential( - linear(time_embed_dim, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - if self.num_classes is not None: - self.label_emb = nn.Embedding(num_classes, time_embed_dim) - - ch = input_ch = int(self.channel_mult[0] * model_channels) - self.input_blocks = nn.ModuleList( - [TimestepEmbedSequential( - conv_nd(dims, in_channels, ch, 3, padding=1))] - ) - self._feature_size = ch - input_block_chans = [ch] - ds = 1 - for level, mult in enumerate(self.channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=int(mult * model_channels), - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = int(mult * model_channels) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - 
self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(self.channel_mult) - 1: - out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - - self.output_blocks = nn.ModuleList([]) - for level, mult in list(enumerate(self.channel_mult))[::-1]: - for i in range(num_res_blocks + 1): - ich = input_block_chans.pop() - layers = [ - ResBlock( - ch + ich, - time_embed_dim, - dropout, - out_channels=int(model_channels * mult), - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = int(model_channels * mult) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads_upsample, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - if level and i == num_res_blocks: - out_ch = ch - layers.append( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - up=True, - ) - if resblock_updown - else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) - ) - ds //= 2 - self.output_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1)), - ) - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - self.output_blocks.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - self.output_blocks.apply(convert_module_to_f32) - - def forward(self, x, timesteps, y=None, gt=None, **kwargs): - """ - Apply the model to an input batch. - - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :param y: an [N] Tensor of labels, if class-conditional. - :return: an [N x C x ...] Tensor of outputs. 
- """ - - # if timesteps[0].item() > self.conf.diffusion_steps: - # raise RuntimeError("timesteps larger than diffusion steps.", - # timesteps[0].item(), self.conf.diffusion_steps) - - # if self.conf.use_value_logger: - # self.conf.value_logger.add_to_list( - # 'model_time', timesteps[0].item()) - - hs = [] - emb = self.time_embed(timesteps) - - if self.num_classes is not None: - assert y.shape == (x.shape[0],) - emb = emb + self.label_emb(y) - - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb) - hs.append(h) - h = self.middle_block(h, emb) - for module in self.output_blocks: - h = th.cat([h, hs.pop()], dim=1) - h = module(h, emb) - h = h.type(x.dtype) - return self.out(h) - - -class SuperResModel(UNetModel): - """ - A UNetModel that performs super-resolution. - - Expects an extra kwarg `low_res` to condition on a low-resolution image. - """ - - def __init__(self, image_size, in_channels, *args, **kwargs): - super().__init__(image_size, in_channels * 2, *args, **kwargs) - - def forward(self, x, timesteps, low_res=None, **kwargs): - _, _, new_height, new_width = x.shape - upsampled = F.interpolate( - low_res, (new_height, new_width), mode="bilinear") - x = th.cat([x, upsampled], dim=1) - return super().forward(x, timesteps, **kwargs) - - -class EncoderUNetModel(nn.Module): - """ - The half UNet model with attention and timestep embedding. - - For usage, see UNet. - """ - - def __init__( - self, - image_size, - in_channels, - model_channels, - out_channels, - num_res_blocks, - attention_resolutions, - dropout=0, - channel_mult=(1, 2, 4, 8), - conv_resample=True, - dims=2, - use_checkpoint=False, - use_fp16=False, - num_heads=1, - num_head_channels=-1, - num_heads_upsample=-1, - use_scale_shift_norm=False, - resblock_updown=False, - use_new_attention_order=False, - pool="adaptive", - ): - super().__init__() - - if num_heads_upsample == -1: - num_heads_upsample = num_heads - - self.in_channels = in_channels - self.model_channels = model_channels - self.out_channels = out_channels - self.num_res_blocks = num_res_blocks - self.attention_resolutions = attention_resolutions - self.dropout = dropout - self.channel_mult = channel_mult - self.conv_resample = conv_resample - self.use_checkpoint = use_checkpoint - self.dtype = th.float16 if use_fp16 else th.float32 - self.num_heads = num_heads - self.num_head_channels = num_head_channels - self.num_heads_upsample = num_heads_upsample - - time_embed_dim = model_channels * 4 - self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), - nn.SiLU(), - linear(time_embed_dim, time_embed_dim), - ) - - ch = int(channel_mult[0] * model_channels) - self.input_blocks = nn.ModuleList( - [TimestepEmbedSequential( - conv_nd(dims, in_channels, ch, 3, padding=1))] - ) - self._feature_size = ch - input_block_chans = [ch] - ds = 1 - for level, mult in enumerate(channel_mult): - for _ in range(num_res_blocks): - layers = [ - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=int(mult * model_channels), - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ) - ] - ch = int(mult * model_channels) - if ds in attention_resolutions: - layers.append( - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ) - ) - self.input_blocks.append(TimestepEmbedSequential(*layers)) - self._feature_size += ch - input_block_chans.append(ch) - if level != len(channel_mult) - 1: - 
out_ch = ch - self.input_blocks.append( - TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - out_channels=out_ch, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - down=True, - ) - if resblock_updown - else Downsample( - ch, conv_resample, dims=dims, out_channels=out_ch - ) - ) - ) - ch = out_ch - input_block_chans.append(ch) - ds *= 2 - self._feature_size += ch - - self.middle_block = TimestepEmbedSequential( - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=num_head_channels, - use_new_attention_order=use_new_attention_order, - ), - ResBlock( - ch, - time_embed_dim, - dropout, - dims=dims, - use_checkpoint=use_checkpoint, - use_scale_shift_norm=use_scale_shift_norm, - ), - ) - self._feature_size += ch - self.pool = pool - if pool == "adaptive": - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - nn.AdaptiveAvgPool2d((1, 1)), - zero_module(conv_nd(dims, ch, out_channels, 1)), - nn.Flatten(), - ) - elif pool == "attention": - assert num_head_channels != -1 - self.out = nn.Sequential( - normalization(ch), - nn.SiLU(), - AttentionPool2d( - (image_size // ds), ch, num_head_channels, out_channels - ), - ) - elif pool == "spatial": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - nn.ReLU(), - nn.Linear(2048, self.out_channels), - ) - elif pool == "spatial_v2": - self.out = nn.Sequential( - nn.Linear(self._feature_size, 2048), - normalization(2048), - nn.SiLU(), - nn.Linear(2048, self.out_channels), - ) - else: - raise NotImplementedError(f"Unexpected {pool} pooling") - - def convert_to_fp16(self): - """ - Convert the torso of the model to float16. - """ - self.input_blocks.apply(convert_module_to_f16) - self.middle_block.apply(convert_module_to_f16) - - def convert_to_fp32(self): - """ - Convert the torso of the model to float32. - """ - self.input_blocks.apply(convert_module_to_f32) - self.middle_block.apply(convert_module_to_f32) - - def forward(self, x, timesteps): - """ - Apply the model to an input batch. - - :param x: an [N x C x ...] Tensor of inputs. - :param timesteps: a 1-D batch of timesteps. - :return: an [N x K] Tensor of outputs. 
- """ - emb = self.time_embed(timestep_embedding( - timesteps, self.model_channels)) - - results = [] - h = x.type(self.dtype) - for module in self.input_blocks: - h = module(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = self.middle_block(h, emb) - if self.pool.startswith("spatial"): - results.append(h.type(x.dtype).mean(dim=(2, 3))) - h = th.cat(results, axis=-1) - return self.out(h) - else: - h = h.type(x.dtype) - return self.out(h) diff --git a/diffusion_models/models/sampler.py b/diffusion_models/models/sampler.py index 0dc2426..83eaac9 100644 --- a/diffusion_models/models/sampler.py +++ b/diffusion_models/models/sampler.py @@ -5,8 +5,8 @@ from torch import nn from jaxtyping import Float, Bool from typing import Callable, Literal, Any, Tuple, Union -from models.diffusion import DiffusionModel -from utils.helpers import bytes_to_gb +from diffusion_models.models.diffusion import DiffusionModel +from diffusion_models.utils.helpers import bytes_to_gb from torch.fft import fftn, ifftn, fftshift, ifftshift from tqdm import tqdm import torchvision diff --git a/diffusion_models/models/unet.py b/diffusion_models/models/unet.py index 518484a..b8a2168 100644 --- a/diffusion_models/models/unet.py +++ b/diffusion_models/models/unet.py @@ -199,7 +199,7 @@ def forward( return x class UNet(nn.Module): - """Implementation of UNet architecture, close to original paper. + """Implementation of UNet architecture, close to original paper. [1]_ Things that are different ------------------------- diff --git a/diffusion_models/models/vae.py b/diffusion_models/models/vae.py index e082603..d3f4fc2 100644 --- a/diffusion_models/models/vae.py +++ b/diffusion_models/models/vae.py @@ -34,6 +34,7 @@ class ResNet18Encoder(nn.Module): b. first residual block increases channels to 128, halves size with stride 2, second is standard c. like b., but to 256 channels d. like b., but to 512 channels + Output of residual blocks has size 7x7 with 512 channels. 
""" def __init__(self, in_channels: int, hidden_dim: int=256) -> None: diff --git a/diffusion_models/mri_forward/fft.py b/diffusion_models/mri_forward/fft.py index 9430f93..f29ba39 100644 --- a/diffusion_models/mri_forward/fft.py +++ b/diffusion_models/mri_forward/fft.py @@ -3,7 +3,7 @@ from jaxtyping import Float, Complex from torch import Tensor import torch -from utils.helpers import complex_to_2channelfloat +from diffusion_models.utils.helpers import complex_to_2channelfloat def to_kspace( x: Union[ diff --git a/diffusion_models/utils/datasets.py b/diffusion_models/utils/datasets.py index 4dd488d..0e2ff34 100644 --- a/diffusion_models/utils/datasets.py +++ b/diffusion_models/utils/datasets.py @@ -10,7 +10,7 @@ import torch import h5py from torch.fft import ifft2, fft2, fftshift, ifftshift, fftn, ifftn -from utils.helpers import complex_to_2channelfloat +from diffusion_models.utils.helpers import complex_to_2channelfloat class Cifar10Dataset(CIFAR10): def __init__(self, root: str, train: bool = True, transform: Callable[..., Any] | None = None, target_transform: Callable[..., Any] | None = None, download: bool = False) -> None: diff --git a/diffusion_models/utils/trainer.py b/diffusion_models/utils/trainer.py index db87855..bfe1c30 100644 --- a/diffusion_models/utils/trainer.py +++ b/diffusion_models/utils/trainer.py @@ -16,7 +16,7 @@ import torchvision from math import isqrt from jaxtyping import Float -from utils.helpers import bytes_to_gb +from diffusion_models.utils.helpers import bytes_to_gb from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, StepLR from torch.fft import ifftn diff --git a/docs/source/conf.py b/docs/source/conf.py index dfd48cd..363a742 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,7 +1,6 @@ import os import sys sys.path.append(os.path.abspath("../..")) -import diffusion_models # Configuration file for the Sphinx documentation builder. 
 #
@@ -33,7 +32,11 @@
     "numpy",
     "wandb",
     "torchvision",
-    "h5py"
+    "h5py",
+    "tqdm",
+    "time",
+    "typing",
+    "math"
 ]
 
 templates_path = ['_templates']
diff --git a/packages.dot b/packages.dot
deleted file mode 100644
index d9a1c62..0000000
--- a/packages.dot
+++ /dev/null
@@ -1,25 +0,0 @@
-digraph "packages" {
-rankdir=BT
-charset="utf-8"
-"diffusion_models" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.losses" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.losses.kl_divergence" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.models" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.models.diffusion" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.models.mnist_enc" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.models.multicoil" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.models.positional_encoding" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.models.sampler" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.models.unet" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.models.vae" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.mri_forward" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.mri_forward.fft" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.mri_forward.noise" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.mri_forward.undersampling_mask" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.utils" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.utils.datasets" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.utils.helpers" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.utils.mp_setup" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.utils.trainer" [color="black", label=, shape="box", style="solid"];
-"diffusion_models.models.sampler" -> "diffusion_models.models.diffusion" [arrowhead="open", arrowtail="none"];
-}
diff --git a/requirements.txt b/requirements.txt
index 75edf1a..9099e32 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,4 +3,5 @@ lightning==2.2.2
 jupyter==1.0.0
 wandb==0.16.6
 h5py==3.11.0
-furo==2024.1.29
\ No newline at end of file
+furo==2024.1.29
+hydra-core==1.3.2
\ No newline at end of file
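
The KL-divergence docstrings touched above (diffusion_models/losses/kl_divergence.py) refer to the closed-form KL divergence between two univariate Gaussians with k=1, in both a plain-variance and a log-variance parameterization. A minimal reference sketch of that formula follows; the function and parameter names below are illustrative assumptions, since the module's actual signatures are not visible in the hunks.

    # Reference sketch only -- illustrates the closed-form univariate Gaussian KL
    # divergence described in the patched docstrings; names and shapes are
    # assumptions, not the repository's actual API.
    import torch
    from torch import Tensor


    def gaussian_kl_reference(mean1: Tensor, var1: Tensor, mean2: Tensor, var2: Tensor) -> Tensor:
        """KL( N(mean1, var1) || N(mean2, var2) ) for univariate Gaussians (k = 1)."""
        return 0.5 * (torch.log(var2 / var1) + (var1 + (mean1 - mean2) ** 2) / var2 - 1.0)


    def log_gaussian_kl_reference(mean1: Tensor, logvar1: Tensor, mean2: Tensor, logvar2: Tensor) -> Tensor:
        """Same divergence expressed with log variances, as in the log_gaussian_kl docstring."""
        return 0.5 * (
            logvar2 - logvar1
            + torch.exp(logvar1 - logvar2)
            + (mean1 - mean2) ** 2 * torch.exp(-logvar2)
            - 1.0
        )

The log-variance form avoids dividing by a possibly tiny variance and is the parameterization commonly used for variational bounds in DDPM-style training.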