Skip to content

Commit

Permalink
ENH: Refactor filters into filtering
Browse files Browse the repository at this point in the history
Refactor code blocks that perform filtering operations within models
into the `filtering` module so that the model and filter concepts are
separated, supporting more cleanly pipelining models with filters within
the new `Estimator` philosophy.

Add the percentile-based detrending feature that was temporarily removed
in commit 7322e93.
  • Loading branch information
jhlegarreta committed Jan 25, 2025
1 parent 7237ecf commit 7e67f76
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 42 deletions.
119 changes: 119 additions & 0 deletions src/nifreeze/data/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@

from __future__ import annotations

import copy

import numpy as np
from scipy.ndimage import median_filter
from skimage.morphology import ball

from nifreeze.data.dmri import DEFAULT_CLIP_PERCENTILE, DWI


DEFAULT_DTYPE = "int16"
"""The default image's data type."""

Expand Down Expand Up @@ -96,3 +101,117 @@ def advanced_clip(
data = np.round(255 * data).astype(dtype)

return data


def detrend_data_percentile(data : np.ndarray, mask : np.ndarray | None = None) -> np.ndarray:
"""Detrend data.
Regresses out global signal differences so that its values are centered around the middle 90%
of the data following:
.. math::
\text{data}_{\text{detrended}} = \frac{(\text{data} - p_{5}) \cdot p_{\text{mean}}}{p_{\text{range}}} + p_{5}^{\text{mean}}
where
.. math::
p_{\text{range}} = p_{95} - p_{5}, \quad p_{\text{mean}} = \frac{1}{N} \sum_{i=1}^N p_{\text{range}_i}, \quad p_{5}^{\text{mean}} = \frac{1}{N} \sum_{i=1}^N p_{5_i}
:math:`p_{5}` and :math:`p_{95}` being the 5th percentile and the 95th percentile of the data,
respectively.
If a mask is provided, only the data within the mask are considered.
Parameters
----------
data : :obj:`~numpy.ndarray`
Data to be detrended.
mask : :obj:`~numpy.ndarray`, optional
Mask. If provided, only the data within the mask are considered.
Returns
-------
:obj:`~numpy.ndarray`
Detrended data.
"""

data = data.copy().astype("float32")
reshaped_data = (
data.reshape((-1, data.shape[-1])) if mask is None else data[mask]
)
p5 = np.percentile(reshaped_data, 5.0, axis=0)
p95 = np.percentile(reshaped_data, 95.0, axis=0) - p5
return (data - p5) * p95.mean() / p95 + p5.mean()


def detrend_dwi_median(data : np.ndarray, mask : np.ndarray | None = None) -> np.ndarray:
"""Detrend DWI data.
Regresses out global DWI signal differences so that its standardized and centered around the
:data:`src.nifreeze.model.base.DEFAULT_CLIP_PERCENTILE` percentile.
If a mask is provided, only the data within the mask are considered.
Parameters
----------
data : :obj:`~numpy.ndarray`
Data to be detrended.
mask : :obj:`~numpy.ndarray`, optional
Mask. If provided, only the data within the mask are considered.
Returns
-------
:obj:`~numpy.ndarray`
Detrended data.
"""

shelldata = data[..., mask]

centers = np.median(shelldata, axis=(0, 1, 2))
reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE)
centers[centers < 1.0] = reference
drift = reference / centers
return shelldata * drift


def clip_dwi_shell_data(dataset: DWI, index : int, th_low : int = 100, th_high : int = 100) -> DWI:
"""Clip DWI shell data around the given index and lower and upper b-value bounds.
Clip DWI data around the given index with the provided lower and upper bound b-values.
Parameters
----------
dataset : :obj:`~nifreeze.data.dmri.DWI`
Reference to a DWI object.
index : :obj:`int`
Index of the shell data.
th_low : :obj:`numbers.Number`, optional
A lower bound for the b-value.
th_high : :obj:`numbers.Number`, optional
An upper bound for the b-value.
Returns
-------
clipped_dataset : :obj:`~nifreeze.data.dmri.DWI`
Clipped dataset.
"""

clipped_dataset = copy.deepcopy(dataset)

bvalues = clipped_dataset.gradients[:, -1]
bcenter = bvalues[index]

shellmask = np.ones(len(clipped_dataset._dataset), dtype=bool)

# Keep only bvalues within the range defined by th_high and th_low
shellmask[index] = False
shellmask[bvalues > (bcenter + th_high)] = False
shellmask[bvalues < (bcenter - th_low)] = False

if not shellmask.sum():
raise RuntimeError(
f"Shell corresponding to index {index} (b={bcenter}) is empty.")

clipped_dataset._dataset = clipped_dataset._dataset.dataobj[..., shellmask]

return clipped_dataset
45 changes: 3 additions & 42 deletions src/nifreeze/model/dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,64 +171,25 @@ def fit_predict(self, index, **kwargs):
class AverageDWIModel(ExpectationModel):
"""A trivial model that returns an average DWI volume."""

__slots__ = ("_th_low", "_th_high", "_detrend")

def __init__(self, dataset, stat="median", th_low=100, th_high=100, detrend=False, **kwargs):
def __init__(self, dataset, stat="median", **kwargs):
r"""
Implement object initialization.
Parameters
----------
th_low : :obj:`numbers.Number`
A lower bound for the b-value corresponding to the diffusion weighted images
that will be averaged.
th_high : :obj:`numbers.Number`
An upper bound for the b-value corresponding to the diffusion weighted images
that will be averaged.
detrend : :obj:`bool`
Whether the overall distribution of each diffusion weighted image will be
standardized and centered around the
:data:`src.nifreeze.model.base.DEFAULT_CLIP_PERCENTILE` percentile.
stat : :obj:`str`
Whether the summary statistic to apply is ``"mean"`` or ``"median"``.
"""
super().__init__(dataset, stat=stat, **kwargs)

self._th_low = th_low
self._th_high = th_high
self._detrend = detrend

def fit_predict(self, index, *_, **kwargs):
def fit_predict(self, *_, **kwargs):
"""Return the average map."""

bvalues = self._dataset.gradients[:, -1]
bcenter = bvalues[index]

shellmask = np.ones(len(self._dataset), dtype=bool)

# Keep only bvalues within the range defined by th_high and th_low
shellmask[index] = False
shellmask[bvalues > (bcenter + self._th_high)] = False
shellmask[bvalues < (bcenter - self._th_low)] = False

if not shellmask.sum():
raise RuntimeError(f"Shell corresponding to index {index} (b={bcenter}) is empty.")

shelldata = self._dataset.dataobj[..., shellmask]

# Regress out global signal differences
if self._detrend:
centers = np.median(shelldata, axis=(0, 1, 2))
reference = np.percentile(centers[centers >= 1.0], DEFAULT_CLIP_PERCENTILE)
centers[centers < 1.0] = reference
drift = reference / centers
shelldata = shelldata * drift

# Select the summary statistic
avg_func = np.median if self._stat == "median" else np.mean
# Calculate the average
return avg_func(shelldata, axis=-1)
return avg_func(self._dataset.dataobj, axis=-1)


class DTIModel(BaseDWIModel):
Expand Down

0 comments on commit 7e67f76

Please sign in to comment.