Skip to content

Commit

Permalink
espirit stub from previous hackathon
Browse files Browse the repository at this point in the history
  • Loading branch information
ckolbPTB committed Jul 3, 2024
1 parent 3912033 commit 4424243
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/mrpro/algorithms/csm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from mrpro.algorithms.csm.iterative_walsh import iterative_walsh
from mrpro.algorithms.csm.espirit import espirit
71 changes: 71 additions & 0 deletions src/mrpro/algorithms/csm/espirit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""ESPIRIT method for coil sensitivity map calculation."""

# Copyright 2024 Physikalisch-Technische Bundesanstalt
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from einops import rearrange


def espirit(
calib: torch.Tensor,
img_shape,
thresh=0.02,
kernel_width=6,
crop=0.95,
max_iter=10,
):
# inspired by https://sigpy.readthedocs.io/en/latest/_modules/sigpy/mri/app.html#EspiritCalib

# Get calibration matrix.
# Shape [num_coils] + num_blks + [kernel_width] * img_ndim
mat = calib
for ax in (1, 2, 3):
mat = mat.unfold(dimension=ax, size=min(calib.shape[ax], kernel_width), step=1)
num_coils, _, _, _, c, b, a = mat.shape
mat = rearrange(mat, 'coils z y x c b a -> (z y x) (coils c b a)')

# Perform SVD on calibration matrix
_, S, VH = torch.linalg.svd(mat, full_matrices=False)

# Get kernels
VH = torch.diag((S > thresh * S.max()).type(VH.type())) @ VH
kernels = rearrange(VH, 'n (coils c b a) -> n coils c b a', coils=num_coils, c=c, b=b, a=a)

# Get covariance matrix in image domain
AHA = torch.zeros((num_coils, num_coils, *img_shape), dtype=calib.dtype, device=calib.device)

for kernel in kernels:
img_kernel = torch.fft.ifftn(kernel, s=img_shape, dim=(-3, -2, -1))
img_kernel = torch.fft.ifftshift(img_kernel, dim=(-1, -2, -3))
AHA += torch.einsum('c z y x, d z y x->c d z y x ', img_kernel, img_kernel.conj())

AHA *= AHA[0, 0].numel() / kernels.shape[-1]

v = AHA.sum(dim=0)
for _ in range(max_iter):
v /= v.norm(dim=0)
v = torch.einsum('abzyx,bzyx->azyx', AHA, v)
max_eig = v.norm(dim=0)
print(max_eig.max())
csm = v / max_eig

# Normalize phase with respect to first channel
csm *= csm[0].conj() / csm[0].abs()

# Crop maps by thresholding eigenvalue
csm *= max_eig #> crop

return csm
43 changes: 43 additions & 0 deletions src/mrpro/data/CsmData.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@

import torch

from mrpro.data._kdata.KData import KData
from mrpro.data.IData import IData
from mrpro.data.QData import QData
from mrpro.data.SpatialDimension import SpatialDimension
from mrpro.utils import smap

if TYPE_CHECKING:
from mrpro.operators.SensitivityOp import SensitivityOp
Expand Down Expand Up @@ -67,6 +69,47 @@ def from_idata_walsh(
csm = cls(header=idata.header, data=csm_tensor)
return csm

@classmethod
def from_kdata_espirit(
cls,
kdata: KData,
thresh: float = 0.02,
kernel_width: int = 6,
max_iter: int = 10,
crop: float = 0.95,
chunk_size_otherdim=None,
) -> CsmData:
"""Espirit sensitivity Estimation (DRAFT)
Works only for Cartesian K Data
Parameters
----------
kdata
_description_
chunk_size_otherdim, optional
_description_, by default None
"""
from mrpro.algorithms.csm.espirit import espirit
# kdata.data = kdata.data.repeat(2,1,1,1,1)

# check for cartesian
# get calib
_, _, nz, ny, nx = kdata.data.shape
blen = 10
nz_l, nz_u = (nz - blen) // 2, (nz + blen) // 2
ny_l, ny_u = (ny - blen) // 2, (ny + blen) // 2
nx_l, nx_u = (nx - blen) // 2, (nx + blen) // 2
calib = kdata.data[:, :, nz_l:nz_u, ny_l:ny_u, nx_l:nx_u]
img_shape = kdata.data.shape[-3:]

csm_fun = lambda c: espirit(
c, img_shape=img_shape, thresh=thresh, kernel_width=kernel_width, max_iter=max_iter, crop=crop
)
csm_data = smap(csm_fun, calib, passed_dimensions=(1, 2, 3, 4))
return cls(header=kdata.header, data=csm_data)

def as_operator(self) -> SensitivityOp:
"""Create SensitivityOp using a copy of the CSMs."""
from mrpro.operators.SensitivityOp import SensitivityOp
Expand Down

0 comments on commit 4424243

Please sign in to comment.