
Commit c177869: added random deform
1 parent 335dc49

13 files changed, +2571 −24 lines

check_dataset.ipynb

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusion_models.utils.datasets import LumbarSpineDataset\n",
    "\n",
    "ds = LumbarSpineDataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "163584"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "liotorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}

create_dataset.ipynb

Lines changed: 1393 additions & 0 deletions
Large diffs are not rendered by default.

diffusion_models/models/diffusion_openai.py

Lines changed: 2 additions & 2 deletions
@@ -90,8 +90,8 @@ def denoise_singlestep(
         """
         self.model.eval()
         with torch.no_grad():
-            t_enc = self.time_encoder.get_pos_encoding(t)
-            noise_pred = self.model(x, t_enc)
+            # t_enc = self.time_encoder.get_pos_encoding(t)
+            noise_pred = self.model(x, t/self.fwd_diff.timesteps)
             alpha = self.fwd_diff.alphas[t][:, None, None, None]
             alpha_hat = self.fwd_diff.alphas_dash[t][:, None, None, None]
             beta = self.fwd_diff.betas[t][:, None, None, None]
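For context, a minimal sketch of what this change does to the time conditioning; T and the batch size are illustrative, and the commented line stands in for the removed positional encoding:

import torch

T = 1000                       # stands in for self.fwd_diff.timesteps (assumed value)
t = torch.randint(0, T, (8,))  # a batch of sampled timesteps

# before: a positional encoding of t was computed and passed to the model
# t_enc = time_encoder.get_pos_encoding(t)
# after: the timestep is normalized to [0, 1) and passed directly
t_norm = t / T  # integer tensor / int -> float tensor, shape (8,)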

diffusion_models/spine_dataset/__init__.py

Whitespace-only changes.
Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
"""Module containing the augmentation class for random erosion and dilation."""
import numpy as np
import scipy.ndimage as ndimage


class RandomErosion:
    """Random erosion augmentation class."""

    def __init__(
        self, randomState: np.random.RandomState, alpha=0.66, beta=5
    ) -> None:
        """Initialize the random erosion augmentation class.

        Randomly erodes/dilates the image with a probability of alpha and a
        maximum number of iterations of beta.

        Args:
            randomState (np.random.RandomState): random state object to use for random number generation
            alpha (float, optional): Hyperparameter alpha, probability of applying the augmentation. Defaults to 0.66.
            beta (int, optional): Hyperparameter beta, maximum number of erosion/dilation iterations. Defaults to 5.
        """
        self.alpha = alpha
        self.beta = beta
        self.randomState = randomState

    def __call__(self, img_np: np.ndarray) -> np.ndarray:
        """Apply the augmentation to the image.

        Args:
            img_np (np.ndarray): image to augment

        Returns:
            np.ndarray: augmented image
        """
        # Binarize: every non-zero voxel becomes foreground.
        img_np = np.where(img_np != 0, 1, 0).astype(img_np.dtype)

        # Independently erode or dilate each slice along axis 1.
        for i in range(img_np.shape[1]):
            do_augment = self.randomState.rand() < self.alpha
            if do_augment:
                do_erosion = self.randomState.rand() < 0.5
                n_iter = self.randomState.randint(1, self.beta)  # [1, beta)
                if do_erosion:
                    img_np[:, i, :] = ndimage.binary_erosion(
                        img_np[:, i, :], iterations=n_iter
                    ).astype(img_np.dtype)
                else:
                    img_np[:, i, :] = ndimage.binary_dilation(
                        img_np[:, i, :], iterations=n_iter
                    ).astype(img_np.dtype)

        return img_np
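A minimal usage sketch; the volume shape, label values, and seed are illustrative, not part of the commit:

import numpy as np

rng = np.random.RandomState(42)
augment = RandomErosion(rng, alpha=0.66, beta=5)
volume = rng.randint(0, 3, size=(64, 64, 64))  # fake multi-label volume
augmented = augment(volume)  # binarized, then randomly eroded/dilated slice-wise
assert augmented.shape == volume.shape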
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
import torch
from torch.utils.data import Dataset
from abc import ABC, abstractmethod
from typing import Literal, Dict, Any, Optional
from jaxtyping import Float32, UInt, UInt64


class BaseDataset(Dataset, ABC):
    """Interface for Datasets in the Spine Diffusion package.

    This interface is currently not enforced, but any dataset implementation
    should follow the guidelines outlined here; this is especially true for
    the exact returns of the __getitem__ method.
    """

    def __init__(
        self,
        resolution: int,
        random_crop: bool,
        crop_size: int,
        mode: Literal["train", "val", "test"],
        **kwargs
    ):
        """Constructor of BaseDataset.

        Args:
            resolution: determines base resolution of the dataset, i.e. a
                dataset with an original size of 256 (in 3D) will be
                downsampled to that resolution
            random_crop: whether to return random crops of the volume
            crop_size: edge length of the random crops
            mode: which split of the data to load
        """
        raise NotImplementedError

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """__getitem__ method of BaseDataset.

        Args:
            idx: index of desired sample

        Returns:
            dictionary with keys and items as below (not all keys necessary)

        .. code-block:: python

            dict(
                # ch corresponds to num_classes where applicable
                sdf: Optional[Float32[Tensor, "1 res res res"]] = None,
                occ: Optional[
                    # with unique values in range(2, num_classes+1)
                    UInt64[Tensor, "1 res res res"]
                ] = None,
                coords: Optional[Float32[Tensor, "num_points 3"]] = None,
                targets_occ: Optional[Union[
                    # 2 or multi class with probabilities 1
                    UInt64[Tensor, "num_points"],
                    # 2 or multi class with probabilities in [0,1]
                    Float32[Tensor, "num_points num_classes"],
                ]] = None,
                targets_sdf: Optional[Float32[Tensor, "num_points"]] = None,
                loss_fn: Literal["crossentropylogits", "mse"] = "crossentropylogits",
                metadata: Optional[Any] = None,
            )

        - "sdf" is the full volume and should be normalized to [-1,1] range
        - "sdf_target" is the cropped volume, equally normalized, and may be a
          TSDF of the original data to enhance learning
        - "occ_float" is the full volume occupancy as torch.float32,
          normalized to [0,1] range
        - "occ_target" is a cropped binary/multi-class torch.long tensor
        - "vox_coords" contains the coords of the voxel centers of
          "sdf_target" or "occ_target", normalized to [-1,1] range (see
          torch.grid_sample(align_corners=True) for reference). If random_crop
          is False, this is not needed and will default to all voxel centers
          in the volume.
        - "rand_coords" can be used for randomized sampling of coordinates
          instead of voxel centers
        - "rand_targets" can be used for interpolated SDF values
        - "metadata" can be anything
        - while the number of channels "ch" will usually be 1, it might be
          good for multi-class problems to split classes between channels,
          in which case they should be float
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    @staticmethod
    def check_output(output: Dict[str, Any]):
        pass
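For reference, a minimal sketch of a conforming implementation; the class name, shapes, and values are illustrative assumptions, not part of this commit:

import torch
from typing import Any, Dict

class ToyOccupancyDataset(BaseDataset):
    """Hypothetical toy dataset illustrating the __getitem__ contract."""

    def __init__(self, resolution: int = 32, **kwargs):
        # Deliberately skip BaseDataset.__init__, which raises NotImplementedError.
        self.resolution = resolution

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        res = self.resolution
        return dict(
            occ=torch.randint(0, 2, (1, res, res, res)),  # binary occupancy, torch.long
            loss_fn="crossentropylogits",
            metadata={"index": idx},
        )

    def __len__(self):
        return 16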
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
import torch
from typing import List


def collate_fn(batch: List[dict]):
    # One list per key of the first sample; the shared loss_fn string is
    # passed through as-is rather than collected.
    res = {key: [] for key in batch[0].keys()}
    res["loss_fn"] = batch[0]["loss_fn"]
    # Collect tensor values per key; every key except "loss_fn" is expected
    # to hold tensors.
    for sample in batch:
        for key, elem in sample.items():
            if isinstance(elem, torch.Tensor):
                res[key].append(elem)
    # Stack each list of tensors into a batch dimension.
    for key, elem in res.items():
        if isinstance(elem, list):
            res[key] = torch.stack(res[key], dim=0)
        elif isinstance(elem, str):
            assert key == "loss_fn"
        else:
            raise ValueError(f"{key}")
    return res
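A minimal sketch of how this collate function behaves; the key name and tensor shapes are illustrative:

import torch

batch = [
    {"occ": torch.zeros(1, 8, 8, 8), "loss_fn": "crossentropylogits"},
    {"occ": torch.ones(1, 8, 8, 8), "loss_fn": "crossentropylogits"},
]
out = collate_fn(batch)
print(out["occ"].shape)  # torch.Size([2, 1, 8, 8, 8])
print(out["loss_fn"])    # "crossentropylogits"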
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
from spine_diffusion.datasets.random_deformation_dataset import (
    RandomDeformationDataset
)
from spine_diffusion.datasets.shapenet import ShapeNet_Dataset
from typing import Literal, Tuple
from torch.utils.data import DataLoader
from spine_diffusion.datasets.collations import collate_fn


def get_trainval_dataloaders(
    dataset: Literal["spine", "shapenet"],
    config: dict,
    batch_size: int
) -> Tuple[DataLoader, DataLoader]:
    if dataset == "spine":
        train_ds = RandomDeformationDataset(mode="train", **config)
        val_ds = RandomDeformationDataset(mode="val", **config)
    elif dataset == "shapenet":
        train_ds = ShapeNet_Dataset(mode="train", **config)
        val_ds = ShapeNet_Dataset(mode="val", **config)
    else:
        raise ValueError("no such dataset")
    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=batch_size,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=batch_size,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    return train_dl, val_dl
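A minimal usage sketch; the config keys follow BaseDataset's constructor and are assumptions, since this commit does not show RandomDeformationDataset's exact signature:

config = dict(resolution=128, random_crop=True, crop_size=32)  # assumed keys
train_dl, val_dl = get_trainval_dataloaders("spine", config, batch_size=4)
batch = next(iter(train_dl))  # a dict of stacked tensors plus "loss_fn"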
