diff --git a/torchkbnufft/functional/interp.py b/torchkbnufft/functional/interp.py index ce003d6..bc502e9 100644 --- a/torchkbnufft/functional/interp.py +++ b/torchkbnufft/functional/interp.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Tuple, Optional import torch from torch import Tensor @@ -87,6 +87,7 @@ def kb_table_interp( numpoints: Tensor, table_oversamp: Tensor, offsets: Tensor, + ssbasis: Optional[Tensor] = None, ) -> Tensor: """Kaiser-Bessel table interpolation. @@ -113,6 +114,9 @@ def kb_table_interp( is_complex = False image = torch.view_as_complex(image) + + if ssbasis is not None: + image = torch.tensordot(ssbasis,image,dims=([1], [0])) data = KbTableInterpForward.apply( image, omega, tables, n_shift, numpoints, table_oversamp, offsets @@ -133,6 +137,7 @@ def kb_table_interp_adjoint( table_oversamp: Tensor, offsets: Tensor, grid_size: Tensor, + ssbasis: Optional[Tensor] = None, ) -> Tensor: """Kaiser-Bessel table interpolation adjoint. @@ -169,4 +174,8 @@ def kb_table_interp_adjoint( if is_complex is False: image = torch.view_as_real(image) + if ssbasis is not None: + image = torch.tensordot(ssbasis.conj(),image,dims=([0], [0])) + + return image diff --git a/torchkbnufft/functional/nufft.py b/torchkbnufft/functional/nufft.py index 297f275..866cfe8 100644 --- a/torchkbnufft/functional/nufft.py +++ b/torchkbnufft/functional/nufft.py @@ -135,6 +135,7 @@ def kb_table_nufft( table_oversamp: Tensor, offsets: Tensor, norm: Optional[str] = None, + ssbasis: Optional[Tensor] = None, ) -> Tensor: """Kaiser-Bessel NUFFT with table interpolation. @@ -183,6 +184,7 @@ def kb_table_nufft( numpoints=numpoints, table_oversamp=table_oversamp, offsets=offsets, + ssbasis=ssbasis, ) if is_complex is False: @@ -203,6 +205,7 @@ def kb_table_nufft_adjoint( table_oversamp: Tensor, offsets: Tensor, norm: Optional[str] = None, + ssbasis: Optional[Tensor] = None, ) -> Tensor: """Kaiser-Bessel NUFFT adjoint with table interpolation. @@ -247,6 +250,7 @@ def kb_table_nufft_adjoint( table_oversamp=table_oversamp, offsets=offsets, grid_size=grid_size, + ssbasis=ssbasis, ), scaling_coef=scaling_coef, im_size=im_size, diff --git a/torchkbnufft/modules/kbnufft.py b/torchkbnufft/modules/kbnufft.py index b91ea0c..4570e61 100644 --- a/torchkbnufft/modules/kbnufft.py +++ b/torchkbnufft/modules/kbnufft.py @@ -129,6 +129,7 @@ def forward( interp_mats: Optional[Tuple[Tensor, Tensor]] = None, smaps: Optional[Tensor] = None, norm: Optional[str] = None, + ssbasis: Optional[Tensor] = None, ) -> Tensor: """Apply FFT and interpolate from gridded data to scattered data. @@ -220,6 +221,7 @@ def forward( table_oversamp=self.table_oversamp, offsets=self.offsets.to(torch.long), norm=norm, + ssbasis=ssbasis, ) if not is_complex: @@ -311,6 +313,7 @@ def forward( interp_mats: Optional[Tuple[Tensor, Tensor]] = None, smaps: Optional[Tensor] = None, norm: Optional[str] = None, + ssbasis: Optional[Tensor] = None, ) -> Tensor: """Interpolate from scattered data to gridded data and then iFFT. @@ -399,6 +402,7 @@ def forward( table_oversamp=self.table_oversamp, offsets=self.offsets.to(torch.long), norm=norm, + ssbasis=ssbasis, ) if smaps is not None: