Showing 13 changed files with 2,571 additions and 24 deletions.
@@ -0,0 +1,63 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusion_models.utils.datasets import LumbarSpineDataset\n",
    "\n",
    "ds = LumbarSpineDataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "163584"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(ds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "liotorch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
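The notebook above only instantiates LumbarSpineDataset and reports 163,584 samples. A script-form sketch of the same check follows; inspecting a sample as a dict is an assumption, based on the BaseDataset contract documented further down, not something the notebook itself shows:

from diffusion_models.utils.datasets import LumbarSpineDataset

ds = LumbarSpineDataset()
print(len(ds))  # 163584 in the notebook run above

sample = ds[0]  # assumed to be a dict following the BaseDataset interface
for key, value in sample.items():
    print(key, getattr(value, "shape", value))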
Empty file.
@@ -0,0 +1,57 @@
"""Module containing the augmentation class for random erosion and dilation."""
import numpy as np
import scipy.ndimage as ndimage


class RandomErosion:
    """Random erosion augmentation class."""

    def __init__(
        self, randomState: np.random.RandomState, alpha=0.66, beta=5
    ) -> None:
        """Initialize the random erosion augmentation class.

        Randomly erodes/dilates the image with probability alpha and a
        maximum number of iterations beta.

        Args:
            randomState (np.random.RandomState): random state object to use
                for random number generation
            alpha (float, optional): probability of applying the
                augmentation. Defaults to 0.66.
            beta (int, optional): maximum number of erosion/dilation
                iterations. Defaults to 5.
        """
        self.alpha = alpha
        self.beta = beta
        self.randomState = randomState

    def __call__(self, img_np: np.ndarray) -> np.ndarray:
        """Apply the augmentation to the image.

        Args:
            img_np (np.ndarray): image to augment

        Returns:
            np.ndarray: augmented image
        """
        # Binarize: treat every non-zero voxel as occupied.
        img_np = np.where(img_np != 0, 1, 0).astype(img_np.dtype)

        # Erode or dilate each slice along axis 1 independently.
        for i in range(img_np.shape[1]):
            do_augment = self.randomState.rand() < self.alpha

            if do_augment:
                do_erosion = self.randomState.rand() < 0.5
                n_iter = self.randomState.randint(1, self.beta)  # [1, beta)

                if do_erosion:
                    img_np[:, i, :] = ndimage.binary_erosion(
                        img_np[:, i, :], iterations=n_iter
                    ).astype(img_np.dtype)
                else:
                    img_np[:, i, :] = ndimage.binary_dilation(
                        img_np[:, i, :], iterations=n_iter
                    ).astype(img_np.dtype)

        return img_np
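For orientation, a minimal usage sketch of the augmentation above; the seed and the dummy binary volume are made up for illustration, and RandomErosion is assumed to be importable from the module just shown:

import numpy as np

rng = np.random.RandomState(42)
aug = RandomErosion(rng, alpha=0.66, beta=5)

# Dummy binary volume; the augmentation erodes/dilates slices along axis 1.
volume = (np.random.rand(32, 32, 32) > 0.5).astype(np.float32)
augmented = aug(volume)
assert augmented.shape == volume.shape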
@@ -0,0 +1,90 @@
from torch.utils.data import Dataset
from abc import ABC
from typing import Literal, Dict, Any, Optional, Union
from jaxtyping import Float32, UInt64


class BaseDataset(Dataset, ABC):
    """Interface for datasets in the Spine Diffusion package.

    This interface is currently not enforced, but any dataset implementation
    should follow the guidelines outlined here; this is especially true for
    the exact returns of the __getitem__ method.
    """

    def __init__(
        self,
        resolution: int,
        random_crop: bool,
        crop_size: int,
        mode: Literal["train", "val", "test"],
        **kwargs,
    ):
        """Constructor of BaseDataset.

        Args:
            resolution: determines the base resolution of the dataset, i.e. a
                dataset with an original size of 256 (in 3D) will be
                downsampled to that resolution
            random_crop: whether to return random crops of size crop_size
                instead of the full volume
            crop_size: edge length of the random crops in voxels
            mode: which dataset split to use
        """
        raise NotImplementedError

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """__getitem__ method of BaseDataset.

        Args:
            idx: index of the desired sample

        Returns:
            dictionary with keys and items as below (not all keys necessary)

            .. code-block:: python

                dict(
                    # ch corresponds to num_classes where applicable
                    sdf: Optional[Float32[Tensor, "1 res res res"]] = None,
                    # unique values in range(2, num_classes + 1)
                    occ: Optional[UInt64[Tensor, "1 res res res"]] = None,
                    coords: Optional[Float32[Tensor, "num_points 3"]] = None,
                    targets_occ: Union[
                        # binary or multi-class with hard labels
                        UInt64[Tensor, "num_points"],
                        # binary or multi-class with probabilities in [0, 1]
                        Float32[Tensor, "num_points num_classes"],
                        None,
                    ] = None,
                    targets_sdf: Optional[Float32[Tensor, "num_points"]] = None,
                    loss_fn: Literal["crossentropylogits", "mse"] = "crossentropylogits",
                    metadata: Optional[Any] = None,
                )

            - "sdf" is the full volume and should be normalized to [-1, 1] range
            - "sdf_target" is the cropped volume, equally normalized; it may be
              a TSDF of the original data to enhance learning
            - "occ_float" is the full-volume occupancy as torch.float32,
              normalized to [0, 1] range
            - "occ_target" is the cropped binary/multi-class torch.long tensor
            - "vox_coords" contains coords of the voxel centers of "sdf_target"
              or "occ_target", normalized to [-1, 1] range (see
              torch.grid_sample(align_corners=True) for reference). If
              random_crop is False, this is not needed and will default to all
              voxel centers in the volume.
            - "rand_coords" can be used for randomized sampling of coordinates
              instead of voxel centers
            - "rand_targets" can be used for interpolated SDF values
            - "metadata" can be anything
            - while the number of channels "ch" will usually be 1, it might be
              good for multi-class problems to split classes between channels;
              these should be float in that case
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    @staticmethod
    def check_output(output: Dict[str, Any]):
        pass
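A minimal subclass sketch of the interface above. Everything here is illustrative (the class name, sizes, and values are made up); it only demonstrates the dict contract documented in __getitem__:

import torch

class ToyOccupancyDataset(BaseDataset):
    """Hypothetical dataset returning a random binary occupancy volume."""

    def __init__(self, resolution: int = 32, **kwargs):
        # Deliberately not calling super().__init__(), which raises.
        self.resolution = resolution

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        res = self.resolution
        occ = torch.randint(0, 2, (1, res, res, res), dtype=torch.long)
        return dict(occ=occ, loss_fn="crossentropylogits", metadata={"idx": idx})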
@@ -0,0 +1,18 @@
import torch
from typing import List


def collate_fn(batch: List[dict]):
    """Collate sample dicts into a batch; every value except the shared
    "loss_fn" string is expected to be a tensor."""
    res = {key: [] for key in batch[0].keys()}
    res["loss_fn"] = batch[0]["loss_fn"]
    for sample in batch:
        for key, elem in sample.items():
            if isinstance(elem, torch.Tensor):
                res[key].append(elem)
    for key, elem in res.items():
        if isinstance(elem, list):
            # Stack the per-sample tensors along a new batch dimension.
            res[key] = torch.stack(elem, dim=0)
        elif isinstance(elem, str):
            assert key == "loss_fn"
        else:
            raise ValueError(f"unexpected non-tensor entry for key {key!r}")
    return res
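A quick sanity check of the collate function with two made-up samples (shapes are arbitrary):

import torch

def make_sample():
    return {
        "occ": torch.zeros(1, 8, 8, 8, dtype=torch.long),
        "loss_fn": "crossentropylogits",
    }

batch = collate_fn([make_sample(), make_sample()])
print(batch["occ"].shape)  # torch.Size([2, 1, 8, 8, 8])
print(batch["loss_fn"])    # "crossentropylogits"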
@@ -0,0 +1,50 @@
from spine_diffusion.datasets.random_deformation_dataset import (
    RandomDeformationDataset,
)
from spine_diffusion.datasets.shapenet import ShapeNet_Dataset
from typing import Literal, Tuple
from torch.utils.data import DataLoader
from spine_diffusion.datasets.collations import collate_fn


def get_trainval_dataloaders(
    dataset: Literal["spine", "shapenet"],
    config: dict,
    batch_size: int,
) -> Tuple[DataLoader, DataLoader]:
    if dataset == "spine":
        train_ds = RandomDeformationDataset(mode="train", **config)
        val_ds = RandomDeformationDataset(mode="val", **config)
    elif dataset == "shapenet":
        train_ds = ShapeNet_Dataset(mode="train", **config)
        val_ds = ShapeNet_Dataset(mode="val", **config)
    else:
        raise ValueError(f"no such dataset: {dataset}")
    train_dl = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=batch_size,  # scales the worker count with batch size
        collate_fn=collate_fn,
        pin_memory=True,
    )
    val_dl = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=batch_size,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    return train_dl, val_dl
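Example call. The config keys below are illustrative only; they must match the constructor of the chosen dataset class, which this diff does not show:

config = dict(resolution=128, random_crop=True, crop_size=32)  # assumed keys
train_dl, val_dl = get_trainval_dataloaders("spine", config, batch_size=4)

for batch in train_dl:
    shapes = {k: v.shape for k, v in batch.items() if hasattr(v, "shape")}
    print(batch["loss_fn"], shapes)
    break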