Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Refactor filters into filtering #71

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions src/nifreeze/data/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@

from __future__ import annotations

import copy

import numpy as np
from scipy.ndimage import median_filter
from skimage.morphology import ball

from nifreeze.data.dmri import DEFAULT_CLIP_PERCENTILE, DWI

DEFAULT_DTYPE = "int16"
"""The default image's data type."""

Expand Down Expand Up @@ -96,3 +100,114 @@
data = np.round(255 * data).astype(dtype)

return data


def detrend_data_percentile(data: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
r"""Detrend data.

Regresses out global signal differences so that its values are centered around the middle 90%
of the data following:

.. math::
\text{data}_{\text{detrended}} = \frac{(\text{data} - p_{5}) \cdot p_{\text{mean}}}{p_{\text{range}}} + p_{5}^{\text{mean}}

where

.. math::
p_{\text{range}} = p_{95} - p_{5}, \quad p_{\text{mean}} = \frac{1}{N} \sum_{i=1}^N p_{\text{range}_i}, \quad p_{5}^{\text{mean}} = \frac{1}{N} \sum_{i=1}^N p_{5_i}

:math:`p_{5}` and :math:`p_{95}` being the 5th percentile and the 95th percentile of the data,
respectively.

If a mask is provided, only the data within the mask are considered.

Parameters
----------
data : :obj:`~numpy.ndarray`
Data to be detrended.
mask : :obj:`~numpy.ndarray`, optional
Mask. If provided, only the data within the mask are considered.

Returns
-------
:obj:`~numpy.ndarray`
Detrended data.
"""

data = data.copy().astype("float32")
reshaped_data = data.reshape((-1, data.shape[-1])) if mask is None else data[mask]
p5 = np.percentile(reshaped_data, 5.0, axis=0)
p95 = np.percentile(reshaped_data, 95.0, axis=0) - p5
return (data - p5) * p95.mean() / p95 + p5.mean()

Check warning on line 141 in src/nifreeze/data/filtering.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/data/filtering.py#L137-L141

Added lines #L137 - L141 were not covered by tests


def detrend_dwi_median(data: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
"""Detrend DWI data.

Regresses out global DWI signal differences so that its standardized and centered around the
:data:`src.nifreeze.model.base.DEFAULT_CLIP_PERCENTILE` percentile.

If a mask is provided, only the data within the mask are considered.

Parameters
----------
data : :obj:`~numpy.ndarray`
Data to be detrended.
mask : :obj:`~numpy.ndarray`, optional
Mask. If provided, only the data within the mask are considered.

Returns
-------
:obj:`~numpy.ndarray`
Detrended data.
"""

shelldata = data[..., mask]

Check warning on line 165 in src/nifreeze/data/filtering.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/data/filtering.py#L165

Added line #L165 was not covered by tests

centers = np.median(shelldata, axis=(0, 1, 2))
reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE)
centers[centers < 1.0] = reference
drift = reference / centers
return shelldata * drift

Check warning on line 171 in src/nifreeze/data/filtering.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/data/filtering.py#L167-L171

Added lines #L167 - L171 were not covered by tests
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming for above two functions/filters should be improved.



def clip_dwi_shell_data(dataset: DWI, index: int, th_low: int = 100, th_high: int = 100) -> DWI:
"""Clip DWI shell data around the given index and lower and upper b-value bounds.

Clip DWI data around the given index with the provided lower and upper bound b-values.

Parameters
----------
dataset : :obj:`~nifreeze.data.dmri.DWI`
Reference to a DWI object.
index : :obj:`int`
Index of the shell data.
th_low : :obj:`numbers.Number`, optional
A lower bound for the b-value.
th_high : :obj:`numbers.Number`, optional
An upper bound for the b-value.

Returns
-------
clipped_dataset : :obj:`~nifreeze.data.dmri.DWI`
Clipped dataset.
"""

clipped_dataset = copy.deepcopy(dataset)

Check warning on line 196 in src/nifreeze/data/filtering.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/data/filtering.py#L196

Added line #L196 was not covered by tests

bvalues = clipped_dataset.gradients[:, -1]
bcenter = bvalues[index]

Check warning on line 199 in src/nifreeze/data/filtering.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/data/filtering.py#L198-L199

Added lines #L198 - L199 were not covered by tests

shellmask = np.ones(len(clipped_dataset._dataset), dtype=bool)

Check warning on line 201 in src/nifreeze/data/filtering.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/data/filtering.py#L201

Added line #L201 was not covered by tests

# Keep only bvalues within the range defined by th_high and th_low
shellmask[index] = False
shellmask[bvalues > (bcenter + th_high)] = False
shellmask[bvalues < (bcenter - th_low)] = False

Check warning on line 206 in src/nifreeze/data/filtering.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/data/filtering.py#L204-L206

Added lines #L204 - L206 were not covered by tests

if not shellmask.sum():
raise RuntimeError(f"Shell corresponding to index {index} (b={bcenter}) is empty.")

Check warning on line 209 in src/nifreeze/data/filtering.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/data/filtering.py#L209

Added line #L209 was not covered by tests

clipped_dataset._dataset = clipped_dataset._dataset.dataobj[..., shellmask]

Check warning on line 211 in src/nifreeze/data/filtering.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/data/filtering.py#L211

Added line #L211 was not covered by tests

return clipped_dataset

Check warning on line 213 in src/nifreeze/data/filtering.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/data/filtering.py#L213

Added line #L213 was not covered by tests
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we clip the data, maybe the gradients stored in the datasets should also be clipped.

50 changes: 4 additions & 46 deletions src/nifreeze/model/dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@
import numpy as np
from joblib import Parallel, delayed

from nifreeze.data.dmri import (
DEFAULT_CLIP_PERCENTILE,
DTI_MIN_ORIENTATIONS,
)
from nifreeze.data.dmri import DTI_MIN_ORIENTATIONS
from nifreeze.model.base import BaseModel, ExpectationModel


Expand Down Expand Up @@ -171,64 +168,25 @@ def fit_predict(self, index, **kwargs):
class AverageDWIModel(ExpectationModel):
"""A trivial model that returns an average DWI volume."""

__slots__ = ("_th_low", "_th_high", "_detrend")

def __init__(self, dataset, stat="median", th_low=100, th_high=100, detrend=False, **kwargs):
def __init__(self, dataset, stat="median", **kwargs):
r"""
Implement object initialization.

Parameters
----------
th_low : :obj:`numbers.Number`
A lower bound for the b-value corresponding to the diffusion weighted images
that will be averaged.
th_high : :obj:`numbers.Number`
An upper bound for the b-value corresponding to the diffusion weighted images
that will be averaged.
detrend : :obj:`bool`
Whether the overall distribution of each diffusion weighted image will be
standardized and centered around the
:data:`src.nifreeze.model.base.DEFAULT_CLIP_PERCENTILE` percentile.
stat : :obj:`str`
Whether the summary statistic to apply is ``"mean"`` or ``"median"``.

"""
super().__init__(dataset, stat=stat, **kwargs)

self._th_low = th_low
self._th_high = th_high
self._detrend = detrend

def fit_predict(self, index, *_, **kwargs):
def fit_predict(self, *_, **kwargs):
"""Return the average map."""

bvalues = self._dataset.gradients[:, -1]
bcenter = bvalues[index]

shellmask = np.ones(len(self._dataset), dtype=bool)

# Keep only bvalues within the range defined by th_high and th_low
shellmask[index] = False
shellmask[bvalues > (bcenter + self._th_high)] = False
shellmask[bvalues < (bcenter - self._th_low)] = False

if not shellmask.sum():
raise RuntimeError(f"Shell corresponding to index {index} (b={bcenter}) is empty.")

shelldata = self._dataset.dataobj[..., shellmask]

# Regress out global signal differences
if self._detrend:
centers = np.median(shelldata, axis=(0, 1, 2))
reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE)
centers[centers < 1.0] = reference
drift = reference / centers
shelldata = shelldata * drift

# Select the summary statistic
avg_func = np.median if self._stat == "median" else np.mean
# Calculate the average
return avg_func(shelldata, axis=-1)
return avg_func(self._dataset.dataobj, axis=-1)


class DTIModel(BaseDWIModel):
Expand Down
Loading