Commit

Batched nufft (#24)
* First batched nufft implementation

* Performance improvements

* Fixes for toep

* Type fix

* Bug fix, doc updates

* Fix dcomp it num

* Update batch docs

* Add new docs

* Change the performance tips name

* Try to change list

* Increment version

* Fix perf tips

* Code quality, remove test script

* Update performance doc

* Update doc

* Update docs
mmuckley authored Feb 16, 2021
1 parent b5f6581 commit abca3f2
Showing 15 changed files with 967 additions and 136 deletions.
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -69,6 +69,7 @@ References
:caption: User Guide

basic
performance

.. toctree::
:hidden:
134 changes: 134 additions & 0 deletions docs/source/performance.rst
@@ -0,0 +1,134 @@
Performance Tips
================

:py:mod:`torchkbnufft` is written primarily to scale parallelism within
the PyTorch framework. The package's performance bottlenecks come from two sources:
1) advanced indexing and 2) multiplications. Multiplications are handled in a way that
scales well, but advanced indexing is not, due to
`limitations with PyTorch <https://github.com/pytorch/pytorch/issues/29973>`_.
As a result, the package handles growth in problem size very well when that growth is
independent of the indexing bottleneck, such as:

1. Scaling the batch dimension.
2. Scaling the coil dimension.

Generally, you can just add to these dimensions and the package will perform well
without adding much compute time. If you're chasing more speed, some strategies that
might be helpful are listed below.

Using Batched K-space Trajectories
----------------------------------

As of version ``1.1.0``, :py:mod:`torchkbnufft` can use batched k-space trajectories.
If you pass in a variable for ``omega`` with dimensions
``(N, len(im_size), klength)``, the package will parallelize the execution of all
trajectories in the ``N`` dimension. This is useful when ``N`` is very large, as might
occur in dynamic imaging settings. The following shows an example:

.. code-block:: python

    import torch
    import torchkbnufft as tkbn
    import numpy as np
    from skimage.data import shepp_logan_phantom

    batch_size = 12

    x = shepp_logan_phantom().astype(complex)
    im_size = x.shape
    # convert to tensor, unsqueeze batch and coil dimension
    # output size: (batch_size, 1, ny, nx)
    x = torch.tensor(x).unsqueeze(0).unsqueeze(0).to(torch.complex64)
    x = x.repeat(batch_size, 1, 1, 1)

    klength = 64
    ktraj = np.stack(
        (np.zeros(klength), np.linspace(-np.pi, np.pi, klength))
    )
    # convert to tensor, unsqueeze batch dimension
    # output size: (batch_size, 2, klength)
    ktraj = torch.tensor(ktraj).to(torch.float)
    ktraj = ktraj.unsqueeze(0).repeat(batch_size, 1, 1)

    nufft_ob = tkbn.KbNufft(im_size=im_size)
    # outputs a (batch_size, 1, klength) vector of k-space data
    kdata = nufft_ob(x, ktraj)

This code will then compute k-space data for all 12 trajectories in the batch (here,
12 copies of a single spoke, for simplicity) while parallelizing as much as possible.
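
The example above repeats one spoke ``batch_size`` times. As a hedged sketch of how a batch of distinct trajectories with the ``(N, len(im_size), klength)`` layout might be built, the NumPy snippet below rotates that spoke through evenly spaced angles. The helper name ``radial_spokes`` and the half-circle angle spacing are illustrative assumptions, not part of :py:mod:`torchkbnufft`.

```python
import numpy as np


def radial_spokes(num_spokes: int, klength: int) -> np.ndarray:
    """Return 2D radial spokes with shape (num_spokes, 2, klength).

    Hypothetical helper for illustration only.
    """
    # sample positions along a single spoke, in radians
    radius = np.linspace(-np.pi, np.pi, klength)
    # evenly spaced spoke angles over a half circle (an assumption)
    angles = np.arange(num_spokes) * np.pi / num_spokes
    # rotate the spoke by each angle; rows follow the (ky, kx) order
    # used by the example above
    ky = radius[None, :] * np.sin(angles)[:, None]
    kx = radius[None, :] * np.cos(angles)[:, None]
    return np.stack((ky, kx), axis=1)


ktraj_np = radial_spokes(num_spokes=12, klength=64)
print(ktraj_np.shape)  # (12, 2, 64)
```

The first spoke (angle zero) matches the trajectory in the example above. The array can then be converted with ``torch.tensor(ktraj_np).to(torch.float)`` and passed as ``omega`` in the same way.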

Lowering the Precision
----------------------

A simple way to save both memory and compute time is to decrease the precision. PyTorch
normally operates at a default 32-bit floating point precision, but if you're converting
data from NumPy then you might have some data at 64-bit floating precision. To use
32-bit precision, simply do the following:

.. code-block:: python

    image = image.to(dtype=torch.complex64)
    ktraj = ktraj.to(dtype=torch.float32)
    forw_ob = forw_ob.to(image)

    data = forw_ob(image, ktraj)

The ``forw_ob.to(image)`` command will automagically determine the type for both real
and complex tensors registered as buffers under ``forw_ob``, so you should be able to
do this safely in your code.

In many cases, the accuracy lost in going from 64-bit to 32-bit is not severe, so you
can safely use 32-bit precision.
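
To get a rough feel for this tradeoff, the NumPy-only sketch below casts random complex data from 64-bit to 32-bit components and reports the memory saved and the relative rounding error. It measures storage rounding only, not error accumulated inside a NUFFT, and the array size and seed are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
# complex128 (two 64-bit floats per sample) is NumPy's default complex type
data64 = rng.standard_normal(4096) + 1j * rng.standard_normal(4096)
# complex64 stores two 32-bit floats per sample
data32 = data64.astype(np.complex64)

# memory use is halved
print(data64.nbytes, data32.nbytes)  # 65536 32768

# relative error introduced by the cast alone
rel_err = np.linalg.norm(data64 - data32) / np.linalg.norm(data64)
print(rel_err < 1e-6)  # True
```

Whether the error that builds up over a full reconstruction is acceptable still depends on the application, so it is worth comparing against a 64-bit run on representative data.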

Lowering the Oversampling Ratio
-------------------------------

If you create a :py:class:`~torchkbnufft.KbNufft` object using the following code:

.. code-block:: python

    forw_ob = tkbn.KbNufft(im_size=im_size)

then by default it will use a 2-factor oversampled grid. For some applications, this can
be overkill. If you can sacrifice some accuracy for your application, you can use a
smaller grid with 1.25-factor oversampling by altering how you initialize NUFFT objects
like :py:class:`~torchkbnufft.KbNufft`:

.. code-block:: python

    grid_size = tuple([int(el * 1.25) for el in im_size])
    forw_ob = tkbn.KbNufft(im_size=im_size, grid_size=grid_size)

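
As a back-of-the-envelope sketch of what the smaller grid buys, the snippet below compares the number of grid points at the two oversampling factors for a 256 x 256 image. The helper ``oversampled_grid`` is hypothetical and simply mirrors the ``int(el * 1.25)`` expression above. Grid point count is only a proxy for FFT cost and memory, not a measured timing.

```python
from math import prod


def oversampled_grid(im_size, factor):
    """Hypothetical helper: scale each image dimension by the oversampling factor."""
    return tuple(int(el * factor) for el in im_size)


im_size = (256, 256)
grid_2x = oversampled_grid(im_size, 2.0)
grid_125 = oversampled_grid(im_size, 1.25)
print(grid_2x, grid_125)  # (512, 512) (320, 320)

# ratio of grid points between the default and the smaller grid
print(prod(grid_2x) / prod(grid_125))  # 2.56
```

In 3D the same factors give a ratio of about 4.1, so the saving grows with the number of imaging dimensions.
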
Using Fewer Interpolation Neighbors
-----------------------------------

Another major speed factor is how many neighbors you use for interpolation. By default,
:py:mod:`torchkbnufft` uses 6 nearest neighbors in each dimension. If you can sacrifice
accuracy, you can gain speed by using fewer neighbors when you initialize
NUFFT objects like :py:class:`~torchkbnufft.KbNufft`:

.. code-block:: python

    forw_ob = tkbn.KbNufft(im_size=im_size, numpoints=4)

If you know that you can be less accurate in one dimension (e.g., the z-dimension), then
you can use fewer neighbors in only that dimension:

.. code-block:: python

    forw_ob = tkbn.KbNufft(im_size=im_size, numpoints=(4, 6, 6))

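
A simple cost model shows why this matters. Assuming the interpolation is separable, so that each k-space sample touches the product of the per-dimension neighbor counts, the number of grid points read per sample can be sketched as below. The function ``neighbors_per_sample`` is a hypothetical helper, and the counts are a proxy for indexing work, not a timing.

```python
from math import prod


def neighbors_per_sample(numpoints, ndim):
    """Hypothetical cost model: grid points touched per k-space sample."""
    # an integer numpoints is applied to every dimension
    if isinstance(numpoints, int):
        numpoints = (numpoints,) * ndim
    return prod(numpoints)


print(neighbors_per_sample(6, 3))          # 216 (default, 3D)
print(neighbors_per_sample(4, 3))          # 64
print(neighbors_per_sample((4, 6, 6), 3))  # 144
```

Under this model, dropping from 6 to 4 neighbors in one dimension of a 3D problem already removes a third of the per-sample work.
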
Package Limitations
-------------------

As mentioned earlier, batches and coils scale well, primarily due to the fact that they
don't impact the bottlenecks of the package around advanced indexing. Where
:py:mod:`torchkbnufft` does not scale well is:

1. Very long k-space trajectories.
2. More imaging dimensions (e.g., 3D).

For these settings, you can first try to use some of the strategies here (lowering
precision, fewer neighbors, smaller grid). In some cases, lowering the precision a bit
and using a GPU can still give strong performance. If you're still waiting too long for
compute after trying all of these, you may be running into the limits of the package.
38 changes: 38 additions & 0 deletions tests/test_dcomp.py
@@ -1,3 +1,4 @@
import numpy as np
import pytest
import torch
import torchkbnufft as tkbn
@@ -37,3 +38,40 @@ def test_dcomp_run(shape, kdata_shape, is_complex):
    _ = adj_ob(kdata * dcomp, ktraj)

    torch.set_default_dtype(default_dtype)


@pytest.mark.parametrize(
    "shape, kdata_shape",
    [
        ([2, 1, 19], [2, 1, 25]),
        ([3, 1, 13], [3, 1, 18]),
        ([6, 1, 32, 16], [6, 1, 83]),
        ([5, 1, 15, 12], [5, 1, 83]),
        ([3, 2, 13, 18, 12], [3, 2, 112]),
        ([2, 2, 17, 19, 12], [2, 2, 112]),
    ],
)
def test_batched_dcomp(shape, kdata_shape):
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    torch.manual_seed(123)
    im_size = shape[2:]

    ktraj = (
        torch.rand(size=(shape[0], len(im_size), kdata_shape[2])) * 2 * np.pi - np.pi
    )

    forloop_dcomp = []
    for ktraj_it in ktraj:
        res = tkbn.calc_density_compensation_function(ktraj=ktraj_it, im_size=im_size)
        forloop_dcomp.append(
            tkbn.calc_density_compensation_function(ktraj=ktraj_it, im_size=im_size)
        )

    batched_dcomp = tkbn.calc_density_compensation_function(
        ktraj=ktraj, im_size=im_size
    )

    assert torch.allclose(torch.cat(forloop_dcomp), batched_dcomp)

    torch.set_default_dtype(default_dtype)
49 changes: 49 additions & 0 deletions tests/test_interp.py
@@ -1,5 +1,6 @@
import pickle

import numpy as np
import pytest
import torch
import torchkbnufft as tkbn
@@ -277,3 +278,51 @@ def test_interp_autograd_gpu(shape, kdata_shape, is_complex):
    nufft_autograd_test(image, kdata, ktraj, forw_ob, adj_ob, spmat)

    torch.set_default_dtype(default_dtype)


@pytest.mark.parametrize(
    "shape, kdata_shape, is_complex",
    [
        ([3, 1, 19], [3, 1, 25], True),
        ([3, 1, 13, 2], [3, 1, 18, 2], False),
        ([4, 1, 32, 16], [4, 1, 83], True),
        ([5, 1, 15, 12, 2], [5, 1, 83, 2], False),
        ([3, 2, 13, 18, 12], [3, 2, 112], True),
        ([2, 2, 17, 19, 12, 2], [2, 2, 112, 2], False),
    ],
)
def test_interp_batches(shape, kdata_shape, is_complex):
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    torch.manual_seed(123)
    if is_complex:
        im_size = shape[2:]
    else:
        im_size = shape[2:-1]

    image = create_input_plus_noise(shape, is_complex)
    kdata = create_input_plus_noise(kdata_shape, is_complex)
    ktraj = (
        torch.rand(size=(shape[0], len(im_size), kdata_shape[2])) * 2 * np.pi - np.pi
    )

    forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size)
    adj_ob = tkbn.KbInterpAdjoint(im_size=im_size, grid_size=im_size)

    forloop_test_forw = []
    for image_it, ktraj_it in zip(image, ktraj):
        forloop_test_forw.append(forw_ob(image_it.unsqueeze(0), ktraj_it))

    batched_test_forw = forw_ob(image, ktraj)

    assert torch.allclose(torch.cat(forloop_test_forw), batched_test_forw)

    forloop_test_adj = []
    for data_it, ktraj_it in zip(kdata, ktraj):
        forloop_test_adj.append(adj_ob(data_it.unsqueeze(0), ktraj_it))

    batched_test_adj = adj_ob(kdata, ktraj)

    assert torch.allclose(torch.cat(forloop_test_adj), batched_test_adj)

    torch.set_default_dtype(default_dtype)
59 changes: 59 additions & 0 deletions tests/test_toep.py
@@ -1,3 +1,4 @@
import numpy as np
import pytest
import torch
import torchkbnufft as tkbn
@@ -37,6 +38,64 @@ def test_toeplitz_nufft_accuracy(shape, kdata_shape, is_complex):
    toep_ob = tkbn.ToepNufft()

    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, norm="ortho")
    if not is_complex:
        kernel = torch.view_as_real(kernel)

    fbn = adj_ob(
        forw_ob(image, ktraj, smaps=smaps, norm="ortho"),
        ktraj,
        smaps=smaps,
        norm="ortho",
    )
    fbt = toep_ob(image, kernel, smaps=smaps, norm="ortho")

    if is_complex:
        fbn = torch.view_as_real(fbn)
        fbt = torch.view_as_real(fbt)

    norm_diff = torch.norm(fbn - fbt) / torch.norm(fbn)

    assert norm_diff < norm_diff_tol

    torch.set_default_dtype(default_dtype)


@pytest.mark.parametrize(
    "shape, kdata_shape, is_complex",
    [
        ([4, 3, 19], [4, 3, 25], True),
        ([3, 5, 13, 2], [3, 5, 18, 2], False),
        ([2, 4, 32, 16], [2, 4, 83], True),
        ([5, 8, 15, 12, 2], [5, 8, 83, 2], False),
        ([3, 10, 13, 18, 12], [3, 10, 112], True),
        ([2, 12, 17, 19, 12, 2], [2, 12, 112, 2], False),
    ],
)
def test_batched_toeplitz_nufft_accuracy(shape, kdata_shape, is_complex):
    norm_diff_tol = 1e-4  # toeplitz is only approximate
    default_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch.double)
    torch.manual_seed(123)
    if is_complex:
        im_size = shape[2:]
    else:
        im_size = shape[2:-1]
    im_shape = [s for s in shape]
    im_shape[1] = 1

    image = create_input_plus_noise(im_shape, is_complex)
    smaps = create_input_plus_noise(shape, is_complex)
    ktraj = (
        torch.rand(size=(shape[0], len(im_size), kdata_shape[2])) * 2 * np.pi - np.pi
    )

    forw_ob = tkbn.KbNufft(im_size=im_size)
    adj_ob = tkbn.KbNufftAdjoint(im_size=im_size)
    toep_ob = tkbn.ToepNufft()

    kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, norm="ortho")
    if not is_complex:
        kernel = torch.view_as_real(kernel)

    fbn = adj_ob(
        forw_ob(image, ktraj, smaps=smaps, norm="ortho"),
2 changes: 1 addition & 1 deletion torchkbnufft/__init__.py
@@ -1,6 +1,6 @@
"""Package info"""

-__version__ = "1.0.1"
+__version__ = "1.1.0"
__author__ = "Matthew Muckley"
__author_email__ = "[email protected]"
__license__ = "MIT"
24 changes: 18 additions & 6 deletions torchkbnufft/_nufft/dcomp.py
@@ -23,8 +23,9 @@ def calc_density_compensation_function(
     This function has optional parameters for initializing a NUFFT object. See
     :py:class:`~torchkbnufft.KbInterp` for details.
 
-    * :attr:`ktraj` should be of size ``(len(im_size), klength)``,
-      where ``klength`` is the length of the k-space trajectory.
+    * :attr:`ktraj` should be of size ``(len(grid_size), klength)`` or
+      ``(N, len(grid_size), klength)``, where ``klength`` is the length of the
+      k-space trajectory.
 
     Based on the `method of Pipe
     <https://doi.org/10.1002/(SICI)1522-2594(199901)41:1%3C179::AID-MRM25%3E3.0.CO;2-V>`_.
@@ -56,6 +57,16 @@ def calc_density_compensation_function(
         >>> image = adjkb_ob(data * dcomp, omega)
     """
     device = ktraj.device
+    batch_size = 1
+
+    if ktraj.ndim not in (2, 3):
+        raise ValueError("ktraj must have 2 or 3 dimensions")
+
+    if ktraj.ndim == 3:
+        if ktraj.shape[0] == 1:
+            ktraj = ktraj[0]
+        else:
+            batch_size = ktraj.shape[0]
 
     # init nufft variables
     (
@@ -80,10 +91,12 @@ def calc_density_compensation_function(
         device=device,
     )
 
-    test_sig = torch.ones([1, 1, ktraj.shape[-1]], dtype=tables[0].dtype, device=device)
+    test_sig = torch.ones(
+        [batch_size, 1, ktraj.shape[-1]], dtype=tables[0].dtype, device=device
+    )
     for _ in range(num_iterations):
         new_sig = tkbnF.kb_table_interp(
-            tkbnF.kb_table_interp_adjoint(
+            image=tkbnF.kb_table_interp_adjoint(
                 data=test_sig,
                 omega=ktraj,
                 tables=tables,
@@ -101,7 +114,6 @@ def calc_density_compensation_function(
             offsets=offsets_t,
         )
 
-        norm_new_sig = torch.abs(new_sig)
-        test_sig = test_sig / norm_new_sig
+        test_sig = test_sig / torch.abs(new_sig)
 
     return test_sig
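
The shape handling added to ``calc_density_compensation_function`` can be summarized with a small sketch that works on plain shape tuples rather than tensors. The function ``normalize_ktraj_shape`` is a hypothetical stand-in written for illustration. It mirrors the checks in the hunks above (2 or 3 dimensions, a singleton batch squeezed away) but is not code from the package.

```python
from typing import Tuple


def normalize_ktraj_shape(shape: Tuple[int, ...]) -> Tuple[int, Tuple[int, ...]]:
    """Return (batch_size, shape used downstream) for a ktraj shape.

    Hypothetical helper that mirrors the dimension checks in the diff.
    """
    if len(shape) not in (2, 3):
        raise ValueError("ktraj must have 2 or 3 dimensions")
    if len(shape) == 3:
        if shape[0] == 1:
            # a batch of one is squeezed, as ``ktraj = ktraj[0]`` does
            return 1, shape[1:]
        # a real batch keeps its leading dimension
        return shape[0], shape
    # unbatched trajectory: (len(im_size), klength)
    return 1, shape


print(normalize_ktraj_shape((2, 64)))     # (1, (2, 64))
print(normalize_ktraj_shape((1, 2, 64)))  # (1, (2, 64))
print(normalize_ktraj_shape((5, 2, 64)))  # (5, (5, 2, 64))
```

The returned batch size corresponds to the leading dimension of ``test_sig``, which is why the batched result in ``test_batched_dcomp`` can be compared against the concatenated per-trajectory results.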
4 changes: 3 additions & 1 deletion torchkbnufft/_nufft/fft.py
@@ -146,7 +146,9 @@ def fft_filter(image: Tensor, kernel: Tensor, norm: Optional[str] = "ortho") ->
         raise ValueError("Only option for norm is 'ortho'.")
 
     im_size = torch.tensor(image.shape[2:], dtype=torch.long, device=image.device)
-    grid_size = torch.tensor(kernel.shape[2:], dtype=torch.long, device=image.device)
+    grid_size = torch.tensor(
+        kernel.shape[-len(image.shape[2:]) :], dtype=torch.long, device=image.device
+    )
 
     # set up n-dimensional zero pad
     # zero pad for oversampled nufft
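
The change to ``grid_size`` in ``fft_filter`` swaps a fixed ``[2:]`` slice for one that counts dimensions from the end. A minimal sketch with plain tuples shows the difference. The kernel shapes below are hypothetical examples chosen for illustration; the diff does not state which layouts ``calc_toeplitz_kernel`` returns.

```python
# image layout from the docs: (batch, coil, ny, nx)
image_shape = (4, 8, 32, 16)
num_spatial = len(image_shape[2:])

# hypothetical kernel layouts with zero, one, or two leading dimensions
kernel_shapes = [(64, 32), (4, 64, 32), (1, 1, 64, 32)]

for kernel_shape in kernel_shapes:
    # trailing slice: always the last num_spatial dimensions
    trailing = kernel_shape[-num_spatial:]
    # fixed slice: only correct with exactly two leading dimensions
    fixed = kernel_shape[2:]
    print(kernel_shape, trailing, fixed)
```

Only the trailing slice yields ``(64, 32)`` for all three layouts, which is consistent with the commit note "Fixes for toep".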