Commit

added random deform
liopeer committed May 12, 2024
1 parent 335dc49 commit c177869
Showing 13 changed files with 2,571 additions and 24 deletions.
63 changes: 63 additions & 0 deletions check_dataset.ipynb
@@ -0,0 +1,63 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"from diffusion_models.utils.datasets import LumbarSpineDataset\n",
"\n",
"ds = LumbarSpineDataset()"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"163584"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(ds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "liotorch",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
1,393 changes: 1,393 additions & 0 deletions create_dataset.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions diffusion_models/models/diffusion_openai.py
@@ -90,8 +90,8 @@ def denoise_singlestep(
"""
self.model.eval()
with torch.no_grad():
t_enc = self.time_encoder.get_pos_encoding(t)
noise_pred = self.model(x, t_enc)
# t_enc = self.time_encoder.get_pos_encoding(t)
noise_pred = self.model(x, t/self.fwd_diff.timesteps)
alpha = self.fwd_diff.alphas[t][:, None, None, None]
alpha_hat = self.fwd_diff.alphas_dash[t][:, None, None, None]
beta = self.fwd_diff.betas[t][:, None, None, None]
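The hunk above replaces the positional time encoding with the raw timestep normalized by the total number of diffusion steps. A minimal sketch of that conditioning, using hypothetical shapes and a stub model rather than the repository's UNet:

import torch

timesteps = 1000
x = torch.randn(8, 1, 64, 64)          # batch of noisy images
t = torch.randint(0, timesteps, (8,))  # integer diffusion steps

def model(x: torch.Tensor, t_cond: torch.Tensor) -> torch.Tensor:
    """Stand-in for the denoising network; it receives the normalized step."""
    return torch.zeros_like(x)

# Condition on t normalized to [0, 1) instead of a positional encoding:
noise_pred = model(x, t / timesteps)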
Empty file.
57 changes: 57 additions & 0 deletions diffusion_models/spine_dataset/augmentations.py
@@ -0,0 +1,57 @@
"""Module containing the augmentation class for random erosion and dilation."""
import numpy as np
import scipy.ndimage as ndimage


class RandomErosion:
"""Random erosion augmentation class."""

def __init__(
self, randomState: np.random.RandomState, alpha=0.66, beta=5
) -> None:
"""Initialize the random erosion augmentation class.
Randomly erodes/dilates the image with a probability of alpha and a maximum number of iterations of beta.
Args:
randomState (np.random.RandomState): randomstate object to use for random number generation
alpha (float, optional): Hyperparameter alpha, probability of doing augmentation. Defaults to 0.66.
beta (int, optional): Hyperparameter beta, maximum number of erosion/dilation iterations. Defaults to 5.
"""
self.alpha = alpha
self.beta = beta
self.randomState = randomState

def __call__(self, img_np: np.ndarray) -> np.ndarray:
"""Apply the augmentation to the image.
Args:
img_np (np.ndarray): image to augment
Returns:
np.ndarray: augmented image
"""
img_np = np.where(img_np != 0, 1, 0).astype(img_np.dtype)

for i in range(img_np.shape[1]):
do_augment = self.randomState.rand() < self.alpha

if do_augment:
do_erosion = self.randomState.rand() < 0.5

if do_erosion:
n_iter = self.randomState.randint(
1, self.beta
) # [1, beta)
img_np[:, i, :] = ndimage.binary_erosion(
img_np[:, i, :], iterations=n_iter
).astype(img_np.dtype)
else:
n_iter = self.randomState.randint(
1, self.beta
) # [1, beta)
img_np[:, i, :] = ndimage.binary_dilation(
img_np[:, i, :], iterations=n_iter
).astype(img_np.dtype)

return img_np
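A minimal usage sketch for RandomErosion, assuming the module path from this commit; the input volume and seed are illustrative:

import numpy as np

from diffusion_models.spine_dataset.augmentations import RandomErosion

rng = np.random.RandomState(42)
augment = RandomErosion(rng, alpha=0.66, beta=5)

# Toy binary volume (H, W, D); slices along axis 1 are augmented independently.
volume = (np.random.rand(64, 16, 64) > 0.5).astype(np.float32)
augmented = augment(volume)
print(augmented.shape)  # (64, 16, 64)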
90 changes: 90 additions & 0 deletions diffusion_models/spine_dataset/base_dataset.py
@@ -0,0 +1,90 @@
from abc import ABC, abstractmethod
from typing import Literal, Dict, Any, Optional, Union

import torch
from torch.utils.data import Dataset
from jaxtyping import Float32, UInt64

class BaseDataset(Dataset, ABC):
    """Interface for Datasets in the Spine Diffusion package.

    This interface is currently not enforced, but any dataset implementation
    should follow the guidelines outlined here; this is especially true for
    the exact return values of the __getitem__ method.
    """
    def __init__(
        self,
        resolution: int,
        random_crop: bool,
        crop_size: int,
        mode: Literal["train", "val", "test"],
        **kwargs,
    ):
        """Constructor of BaseDataset.

        Args:
            resolution: determines the base resolution of the dataset, i.e. a
                dataset with an original size of 256 (in 3D) will be
                downsampled to that resolution
            random_crop: whether to return random crops of the volume
        """
        raise NotImplementedError

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """__getitem__ method of BaseDataset.

        Args:
            idx: index of the desired sample

        Returns:
            Dictionary with keys and values as below (not all keys necessary):

            .. code-block:: python

                dict(
                    # ch corresponds to num_classes where applicable
                    sdf: Optional[Float32[Tensor, "1 res res res"]] = None,
                    # unique values in range(2, num_classes+1)
                    occ: Optional[UInt64[Tensor, "1 res res res"]] = None,
                    coords: Optional[Float32[Tensor, "num_points 3"]] = None,
                    targets_occ: Optional[Union[
                        # binary/multi-class with hard labels
                        UInt64[Tensor, "num_points"],
                        # binary/multi-class with probabilities in [0,1]
                        Float32[Tensor, "num_points num_classes"],
                    ]] = None,
                    targets_sdf: Optional[Float32[Tensor, "num_points"]] = None,
                    loss_fn: Literal["crossentropylogits", "mse"] = "crossentropylogits",
                    metadata: Optional[Any] = None,
                )

            - "sdf" is the full volume and should be normalized to the [-1,1] range
            - "sdf_target" is the cropped volume, equally normalized; it may be a
              TSDF of the original data to enhance learning
            - "occ_float" is the full-volume occupancy as torch.float32, normalized
              to the [0,1] range
            - "occ_target" is the cropped binary/multi-class torch.long tensor
            - "vox_coords" contains the coordinates of the voxel centers of
              "sdf_target" or "occ_target", normalized to the [-1,1] range (see
              torch.grid_sample(align_corners=True) for reference). If random_crop
              is False, this is not needed and will default to all voxel centers
              in the volume.
            - "rand_coords" can be used for randomized sampling of coordinates
              instead of voxel centers
            - "rand_targets" can be used for interpolated SDF values
            - "metadata" can be anything
            - while the number of channels "ch" will usually be 1, it might be good
              for multi-class problems to split classes between channels; these
              should then be float
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    @staticmethod
    def check_output(output: Dict[str, Any]):
        pass
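For reference, a hypothetical minimal subclass showing the intended shape of an implementation; the class name, shapes, and keys are illustrative, not part of the commit:

import torch
from typing import Any, Dict, Literal


class ToyOccDataset(BaseDataset):
    """Illustrative in-memory occupancy dataset following the interface."""

    def __init__(self, resolution: int = 32, random_crop: bool = False,
                 crop_size: int = 32,
                 mode: Literal["train", "val", "test"] = "train", **kwargs):
        # Eight random binary volumes standing in for real spine data.
        self.volumes = torch.randint(
            0, 2, (8, 1, resolution, resolution, resolution)
        )

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return dict(
            occ=self.volumes[idx].long(),
            loss_fn="crossentropylogits",
        )

    def __len__(self):
        return len(self.volumes)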
18 changes: 18 additions & 0 deletions diffusion_models/spine_dataset/collations.py
@@ -0,0 +1,18 @@
from typing import List

import torch


def collate_fn(batch: List[dict]) -> dict:
    """Collate samples into a batch: tensor entries are stacked along a new
    batch dimension, while the "loss_fn" string is carried over."""
    res = {key: [] for key in batch[0].keys()}
    # "loss_fn" is a string and identical across samples, so take it as-is.
    res["loss_fn"] = batch[0]["loss_fn"]
    for sample in batch:
        for key, elem in sample.items():
            if isinstance(elem, torch.Tensor):
                res[key].append(elem)
    for key, elem in res.items():
        if isinstance(elem, list):
            # Note: non-tensor entries other than "loss_fn" leave an empty
            # list here and make torch.stack raise.
            res[key] = torch.stack(res[key], dim=0)
        elif isinstance(elem, str):
            assert key == "loss_fn"
        else:
            raise ValueError(f"unexpected entry for key {key}")
    return res
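A quick sketch of how the collation behaves on dummy samples (shapes illustrative):

import torch

from diffusion_models.spine_dataset.collations import collate_fn

batch = [
    dict(occ=torch.zeros(1, 32, 32, 32, dtype=torch.long), loss_fn="crossentropylogits"),
    dict(occ=torch.ones(1, 32, 32, 32, dtype=torch.long), loss_fn="crossentropylogits"),
]
out = collate_fn(batch)
print(out["occ"].shape)  # torch.Size([2, 1, 32, 32, 32])
print(out["loss_fn"])    # crossentropylogits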
50 changes: 50 additions & 0 deletions diffusion_models/spine_dataset/get_dataloader.py
@@ -0,0 +1,50 @@
from typing import Literal, Tuple

from torch.utils.data import DataLoader

from diffusion_models.spine_dataset.collations import collate_fn
from diffusion_models.spine_dataset.random_deformation_dataset import (
    RandomDeformationDataset,
)
from diffusion_models.spine_dataset.shapenet import ShapeNet_Dataset


def get_trainval_dataloaders(
    dataset: Literal["spine", "shapenet"],
    config: dict,
    batch_size: int,
) -> Tuple[DataLoader, DataLoader]:
    """Create train and validation dataloaders for the given dataset."""
    if dataset == "spine":
        train_ds = RandomDeformationDataset(mode="train", **config)
        val_ds = RandomDeformationDataset(mode="val", **config)
    elif dataset == "shapenet":
        train_ds = ShapeNet_Dataset(mode="train", **config)
        val_ds = ShapeNet_Dataset(mode="val", **config)
    else:
        raise ValueError(f"no such dataset: {dataset}")
    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=batch_size,  # scale workers with batch size
        collate_fn=collate_fn,
        pin_memory=True,
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=batch_size,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    return train_dl, val_dl
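A usage sketch, assuming a config dict whose keys match the RandomDeformationDataset constructor (the keys shown are hypothetical):

from diffusion_models.spine_dataset.get_dataloader import get_trainval_dataloaders

# Hypothetical config; the exact keys depend on RandomDeformationDataset.
config = dict(resolution=128, random_crop=True, crop_size=32)
train_dl, val_dl = get_trainval_dataloaders("spine", config, batch_size=4)

first = next(iter(train_dl))
print(first["loss_fn"])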