Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subspace in Batch Dimension #92

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion torchkbnufft/functional/interp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Tuple, Optional

import torch
from torch import Tensor
Expand Down Expand Up @@ -87,6 +87,7 @@ def kb_table_interp(
numpoints: Tensor,
table_oversamp: Tensor,
offsets: Tensor,
ssbasis: Optional[Tensor] = None,
) -> Tensor:
"""Kaiser-Bessel table interpolation.

Expand All @@ -113,6 +114,9 @@ def kb_table_interp(

is_complex = False
image = torch.view_as_complex(image)

if ssbasis is not None:
image = torch.tensordot(ssbasis,image,dims=([1], [0]))

data = KbTableInterpForward.apply(
image, omega, tables, n_shift, numpoints, table_oversamp, offsets
Expand All @@ -133,6 +137,7 @@ def kb_table_interp_adjoint(
table_oversamp: Tensor,
offsets: Tensor,
grid_size: Tensor,
ssbasis: Optional[Tensor] = None,
) -> Tensor:
"""Kaiser-Bessel table interpolation adjoint.

Expand Down Expand Up @@ -169,4 +174,8 @@ def kb_table_interp_adjoint(
if is_complex is False:
image = torch.view_as_real(image)

if ssbasis is not None:
image = torch.tensordot(ssbasis.conj(),image,dims=([0], [0]))


return image
4 changes: 4 additions & 0 deletions torchkbnufft/functional/nufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def kb_table_nufft(
table_oversamp: Tensor,
offsets: Tensor,
norm: Optional[str] = None,
ssbasis: Optional[Tensor] = None,
) -> Tensor:
"""Kaiser-Bessel NUFFT with table interpolation.

Expand Down Expand Up @@ -183,6 +184,7 @@ def kb_table_nufft(
numpoints=numpoints,
table_oversamp=table_oversamp,
offsets=offsets,
ssbasis=ssbasis,
)

if is_complex is False:
Expand All @@ -203,6 +205,7 @@ def kb_table_nufft_adjoint(
table_oversamp: Tensor,
offsets: Tensor,
norm: Optional[str] = None,
ssbasis: Optional[Tensor] = None,
) -> Tensor:
"""Kaiser-Bessel NUFFT adjoint with table interpolation.

Expand Down Expand Up @@ -247,6 +250,7 @@ def kb_table_nufft_adjoint(
table_oversamp=table_oversamp,
offsets=offsets,
grid_size=grid_size,
ssbasis=ssbasis,
),
scaling_coef=scaling_coef,
im_size=im_size,
Expand Down
4 changes: 4 additions & 0 deletions torchkbnufft/modules/kbnufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def forward(
interp_mats: Optional[Tuple[Tensor, Tensor]] = None,
smaps: Optional[Tensor] = None,
norm: Optional[str] = None,
ssbasis: Optional[Tensor] = None,
) -> Tensor:
"""Apply FFT and interpolate from gridded data to scattered data.

Expand Down Expand Up @@ -220,6 +221,7 @@ def forward(
table_oversamp=self.table_oversamp,
offsets=self.offsets.to(torch.long),
norm=norm,
ssbasis=ssbasis,
)

if not is_complex:
Expand Down Expand Up @@ -311,6 +313,7 @@ def forward(
interp_mats: Optional[Tuple[Tensor, Tensor]] = None,
smaps: Optional[Tensor] = None,
norm: Optional[str] = None,
ssbasis: Optional[Tensor] = None,
) -> Tensor:
"""Interpolate from scattered data to gridded data and then iFFT.

Expand Down Expand Up @@ -399,6 +402,7 @@ def forward(
table_oversamp=self.table_oversamp,
offsets=self.offsets.to(torch.long),
norm=norm,
ssbasis=ssbasis,
)

if smaps is not None:
Expand Down