From abca3f28177d4e5519d96917f823aa2b8f1dc1bf Mon Sep 17 00:00:00 2001
From: Matthew Muckley
Date: Tue, 16 Feb 2021 16:38:38 -0500
Subject: [PATCH] Batched nufft (#24)

* First batched nufft implementation
* Performance improvements
* Fixes for toep
* Type fix
* Bug fix, doc updates
* Fix dcomp it num
* Update batch docs
* Add new docs
* Change the performance tips name
* Try to change list
* Increment version
* Fix perf tips
* Code quality, remove test script
* Update performance doc
* Update doc
* Update docs
---
 docs/source/index.rst             |   1 +
 docs/source/performance.rst       | 134 ++++++++++
 tests/test_dcomp.py               |  38 +++
 tests/test_interp.py              |  49 ++++
 tests/test_toep.py                |  59 +++++
 torchkbnufft/__init__.py          |   2 +-
 torchkbnufft/_nufft/dcomp.py      |  24 +-
 torchkbnufft/_nufft/fft.py        |   4 +-
 torchkbnufft/_nufft/interp.py     | 427 ++++++++++++++++++++++++++----
 torchkbnufft/_nufft/spmat.py      |   2 +
 torchkbnufft/_nufft/toep.py       |  71 ++++-
 torchkbnufft/functional/interp.py |  61 ++++-
 torchkbnufft/functional/nufft.py  | 115 +++++---
 torchkbnufft/modules/kbinterp.py  |  43 ++-
 torchkbnufft/modules/kbnufft.py   |  73 ++++-
 15 files changed, 967 insertions(+), 136 deletions(-)
 create mode 100644 docs/source/performance.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index f92c6fc..f2ee186 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -69,6 +69,7 @@ References
    :caption: User Guide

    basic
+   performance

 .. toctree::
    :hidden:

diff --git a/docs/source/performance.rst b/docs/source/performance.rst
new file mode 100644
index 0000000..7f10606
--- /dev/null
+++ b/docs/source/performance.rst
@@ -0,0 +1,134 @@
Performance Tips
================

:py:mod:`torchkbnufft` is written primarily to scale parallelism within the
PyTorch framework. The package's performance bottleneck comes from two sources:
1) advanced indexing and 2) multiplications. Multiplications are handled in a
way that scales well, but advanced indexing does not, due to
`limitations with PyTorch `_.
As a result, the package handles growth in problem size very well when that
growth is independent of the indexing bottleneck, such as:

1. Scaling the batch dimension.
2. Scaling the coil dimension.

Generally, you can just add to these dimensions and the package will perform
well without adding much compute time. If you're chasing more speed, some
strategies that might be helpful are listed below.

Using Batched K-space Trajectories
----------------------------------

As of version ``1.1.0``, :py:mod:`torchkbnufft` can use batched k-space
trajectories. If you pass in a variable for ``omega`` with dimensions
``(N, len(im_size), klength)``, the package will parallelize the execution of
all trajectories in the ``N`` dimension. This is useful when ``N`` is very
large, as might occur in dynamic imaging settings. The following shows an
example:
.. code-block:: python

    import torch
    import torchkbnufft as tkbn
    import numpy as np
    from skimage.data import shepp_logan_phantom

    batch_size = 12

    x = shepp_logan_phantom().astype(np.complex128)
    im_size = x.shape
    # convert to tensor, unsqueeze batch and coil dimension
    # output size: (batch_size, 1, ny, nx)
    x = torch.tensor(x).unsqueeze(0).unsqueeze(0).to(torch.complex64)
    x = x.repeat(batch_size, 1, 1, 1)

    klength = 64
    ktraj = np.stack(
        (np.zeros(klength), np.linspace(-np.pi, np.pi, klength))
    )
    # convert to tensor, unsqueeze batch dimension
    # output size: (batch_size, 2, klength)
    ktraj = torch.tensor(ktraj).to(torch.float)
    ktraj = ktraj.unsqueeze(0).repeat(batch_size, 1, 1)

    nufft_ob = tkbn.KbNufft(im_size=im_size)
    # outputs a (batch_size, 1, klength) tensor of k-space data
    kdata = nufft_ob(x, ktraj)

This code computes the NUFFT over all 12 radial spokes while parallelizing as
much as possible.

Lowering the Precision
----------------------

A simple way to save both memory and compute time is to decrease the
precision. PyTorch normally operates at a default 32-bit floating point
precision, but if you're converting data from NumPy then you might have some
data at 64-bit floating point precision. To use 32-bit precision, simply do
the following:

.. code-block:: python

    image = image.to(dtype=torch.complex64)
    ktraj = ktraj.to(dtype=torch.float32)
    forw_ob = forw_ob.to(image)

    data = forw_ob(image, ktraj)

The ``forw_ob.to(image)`` command will automatically determine the correct
type for both real and complex tensors registered as buffers under
``forw_ob``, so you should be able to do this safely in your code.

In many cases, the accuracy loss in going from 64-bit to 32-bit precision is
not severe, so you can safely use 32-bit precision.

Lowering the Oversampling Ratio
-------------------------------

If you create a :py:class:`~torchkbnufft.KbNufft` object using the following
code:

.. code-block:: python

    forw_ob = tkbn.KbNufft(im_size=im_size)

then by default it will use a 2-factor oversampled grid. For some
applications, this can be overkill. If you can sacrifice some accuracy for
your application, you can use a smaller grid with 1.25-factor oversampling by
altering how you initialize NUFFT objects like
:py:class:`~torchkbnufft.KbNufft`:

.. code-block:: python

    grid_size = tuple([int(el * 1.25) for el in im_size])
    forw_ob = tkbn.KbNufft(im_size=im_size, grid_size=grid_size)

Using Fewer Interpolation Neighbors
-----------------------------------

Another major speed factor is how many neighbors you use for interpolation.
By default, :py:mod:`torchkbnufft` uses 6 nearest neighbors in each dimension.
If you can sacrifice accuracy, you can get more speed by using fewer neighbors
by altering how you initialize NUFFT objects like
:py:class:`~torchkbnufft.KbNufft`:

.. code-block:: python

    forw_ob = tkbn.KbNufft(im_size=im_size, numpoints=4)

If you know that you can be less accurate in one dimension (e.g., the
z-dimension), then you can use fewer neighbors in only that dimension:

.. code-block:: python

    forw_ob = tkbn.KbNufft(im_size=im_size, numpoints=(4, 6, 6))

Package Limitations
-------------------

As mentioned earlier, batches and coils scale well, primarily because they
don't touch the package's advanced-indexing bottleneck. Where
:py:mod:`torchkbnufft` does not scale well is:

1. Very long k-space trajectories.
2. More imaging dimensions (e.g., 3D).
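If you are running up against these limits, the strategies above can be
combined. The sketch below is a rough reference (the sizes are illustrative,
not benchmarks) that applies 32-bit precision, a 1.25-factor grid, and 4
interpolation neighbors to a single forward NUFFT:

.. code-block:: python

    import numpy as np
    import torch
    import torchkbnufft as tkbn

    im_size = (64, 64)
    klength = 1000

    image = torch.randn(1, 1, *im_size, dtype=torch.complex64)
    ktraj = torch.rand(len(im_size), klength) * 2 * np.pi - np.pi

    # a smaller oversampled grid (1.25x) and fewer neighbors (4) trade
    # accuracy for speed; 32-bit types also halve memory versus 64-bit
    grid_size = tuple(int(el * 1.25) for el in im_size)
    forw_ob = tkbn.KbNufft(im_size=im_size, grid_size=grid_size, numpoints=4)

    kdata = forw_ob(image, ktraj.to(torch.float32))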
+ +For these settings, you can first try to use some of the strategies here (lowering +precision, fewer neighbors, smaller grid). In some cases, lowering the precision a bit +and using a GPU can still give strong performance. If you're still waiting too long for +compute after trying all of these, you may be running into the limits of the package. diff --git a/tests/test_dcomp.py b/tests/test_dcomp.py index 8fbdbec..7cb54f6 100644 --- a/tests/test_dcomp.py +++ b/tests/test_dcomp.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch import torchkbnufft as tkbn @@ -37,3 +38,40 @@ def test_dcomp_run(shape, kdata_shape, is_complex): _ = adj_ob(kdata * dcomp, ktraj) torch.set_default_dtype(default_dtype) + + +@pytest.mark.parametrize( + "shape, kdata_shape", + [ + ([2, 1, 19], [2, 1, 25]), + ([3, 1, 13], [3, 1, 18]), + ([6, 1, 32, 16], [6, 1, 83]), + ([5, 1, 15, 12], [5, 1, 83]), + ([3, 2, 13, 18, 12], [3, 2, 112]), + ([2, 2, 17, 19, 12], [2, 2, 112]), + ], +) +def test_batched_dcomp(shape, kdata_shape): + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.double) + torch.manual_seed(123) + im_size = shape[2:] + + ktraj = ( + torch.rand(size=(shape[0], len(im_size), kdata_shape[2])) * 2 * np.pi - np.pi + ) + + forloop_dcomp = [] + for ktraj_it in ktraj: + res = tkbn.calc_density_compensation_function(ktraj=ktraj_it, im_size=im_size) + forloop_dcomp.append( + tkbn.calc_density_compensation_function(ktraj=ktraj_it, im_size=im_size) + ) + + batched_dcomp = tkbn.calc_density_compensation_function( + ktraj=ktraj, im_size=im_size + ) + + assert torch.allclose(torch.cat(forloop_dcomp), batched_dcomp) + + torch.set_default_dtype(default_dtype) diff --git a/tests/test_interp.py b/tests/test_interp.py index e8ccf1e..dfc52f8 100644 --- a/tests/test_interp.py +++ b/tests/test_interp.py @@ -1,5 +1,6 @@ import pickle +import numpy as np import pytest import torch import torchkbnufft as tkbn @@ -277,3 +278,51 @@ def test_interp_autograd_gpu(shape, kdata_shape, is_complex): nufft_autograd_test(image, kdata, ktraj, forw_ob, adj_ob, spmat) torch.set_default_dtype(default_dtype) + + +@pytest.mark.parametrize( + "shape, kdata_shape, is_complex", + [ + ([3, 1, 19], [3, 1, 25], True), + ([3, 1, 13, 2], [3, 1, 18, 2], False), + ([4, 1, 32, 16], [4, 1, 83], True), + ([5, 1, 15, 12, 2], [5, 1, 83, 2], False), + ([3, 2, 13, 18, 12], [3, 2, 112], True), + ([2, 2, 17, 19, 12, 2], [2, 2, 112, 2], False), + ], +) +def test_interp_batches(shape, kdata_shape, is_complex): + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.double) + torch.manual_seed(123) + if is_complex: + im_size = shape[2:] + else: + im_size = shape[2:-1] + + image = create_input_plus_noise(shape, is_complex) + kdata = create_input_plus_noise(kdata_shape, is_complex) + ktraj = ( + torch.rand(size=(shape[0], len(im_size), kdata_shape[2])) * 2 * np.pi - np.pi + ) + + forw_ob = tkbn.KbInterp(im_size=im_size, grid_size=im_size) + adj_ob = tkbn.KbInterpAdjoint(im_size=im_size, grid_size=im_size) + + forloop_test_forw = [] + for image_it, ktraj_it in zip(image, ktraj): + forloop_test_forw.append(forw_ob(image_it.unsqueeze(0), ktraj_it)) + + batched_test_forw = forw_ob(image, ktraj) + + assert torch.allclose(torch.cat(forloop_test_forw), batched_test_forw) + + forloop_test_adj = [] + for data_it, ktraj_it in zip(kdata, ktraj): + forloop_test_adj.append(adj_ob(data_it.unsqueeze(0), ktraj_it)) + + batched_test_adj = adj_ob(kdata, ktraj) + + assert torch.allclose(torch.cat(forloop_test_adj), 
batched_test_adj) + + torch.set_default_dtype(default_dtype) diff --git a/tests/test_toep.py b/tests/test_toep.py index df914bd..c154856 100644 --- a/tests/test_toep.py +++ b/tests/test_toep.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch import torchkbnufft as tkbn @@ -37,6 +38,64 @@ def test_toeplitz_nufft_accuracy(shape, kdata_shape, is_complex): toep_ob = tkbn.ToepNufft() kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, norm="ortho") + if not is_complex: + kernel = torch.view_as_real(kernel) + + fbn = adj_ob( + forw_ob(image, ktraj, smaps=smaps, norm="ortho"), + ktraj, + smaps=smaps, + norm="ortho", + ) + fbt = toep_ob(image, kernel, smaps=smaps, norm="ortho") + + if is_complex: + fbn = torch.view_as_real(fbn) + fbt = torch.view_as_real(fbt) + + norm_diff = torch.norm(fbn - fbt) / torch.norm(fbn) + + assert norm_diff < norm_diff_tol + + torch.set_default_dtype(default_dtype) + + +@pytest.mark.parametrize( + "shape, kdata_shape, is_complex", + [ + ([4, 3, 19], [4, 3, 25], True), + ([3, 5, 13, 2], [3, 5, 18, 2], False), + ([2, 4, 32, 16], [2, 4, 83], True), + ([5, 8, 15, 12, 2], [5, 8, 83, 2], False), + ([3, 10, 13, 18, 12], [3, 10, 112], True), + ([2, 12, 17, 19, 12, 2], [2, 12, 112, 2], False), + ], +) +def test_batched_toeplitz_nufft_accuracy(shape, kdata_shape, is_complex): + norm_diff_tol = 1e-4 # toeplitz is only approximate + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.double) + torch.manual_seed(123) + if is_complex: + im_size = shape[2:] + else: + im_size = shape[2:-1] + im_shape = [s for s in shape] + im_shape[1] = 1 + + image = create_input_plus_noise(im_shape, is_complex) + smaps = create_input_plus_noise(shape, is_complex) + ktraj = ( + torch.rand(size=(shape[0], len(im_size), kdata_shape[2])) * 2 * np.pi - np.pi + ) + + forw_ob = tkbn.KbNufft(im_size=im_size) + adj_ob = tkbn.KbNufftAdjoint(im_size=im_size) + toep_ob = tkbn.ToepNufft() + + kernel = tkbn.calc_toeplitz_kernel(ktraj, im_size, norm="ortho") + if not is_complex: + kernel = torch.view_as_real(kernel) fbn = adj_ob( forw_ob(image, ktraj, smaps=smaps, norm="ortho"), diff --git a/torchkbnufft/__init__.py b/torchkbnufft/__init__.py index b431c61..0129951 100644 --- a/torchkbnufft/__init__.py +++ b/torchkbnufft/__init__.py @@ -1,6 +1,6 @@ """Package info""" -__version__ = "1.0.1" +__version__ = "1.1.0" __author__ = "Matthew Muckley" __author_email__ = "matt.muckley@gmail.com" __license__ = "MIT" diff --git a/torchkbnufft/_nufft/dcomp.py b/torchkbnufft/_nufft/dcomp.py index bd03c52..424f358 100644 --- a/torchkbnufft/_nufft/dcomp.py +++ b/torchkbnufft/_nufft/dcomp.py @@ -23,8 +23,9 @@ def calc_density_compensation_function( This function has optional parameters for initializing a NUFFT object. See :py:class:`~torchkbnufft.KbInterp` for details. - * :attr:`ktraj` should be of size ``(len(im_size), klength)``, - where ``klength`` is the length of the k-space trajectory. + * :attr:`ktraj` should be of size ``(len(grid_size), klength)`` or + ``(N, len(grid_size), klength)``, where ``klength`` is the length of the + k-space trajectory. Based on the `method of Pipe `_. 
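# A minimal usage sketch of the batched dcomp path documented above; the
# shapes below are illustrative. A batched ktraj carries a leading N
# (batch) dimension, and one set of weights is computed per trajectory.
import numpy as np
import torch
import torchkbnufft as tkbn

batch_size, klength = 4, 200
im_size = (64, 64)

ktraj = torch.rand(batch_size, len(im_size), klength) * 2 * np.pi - np.pi

# returns density compensation weights of shape (batch_size, 1, klength)
dcomp = tkbn.calc_density_compensation_function(ktraj=ktraj, im_size=im_size)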
@@ -56,6 +57,16 @@ def calc_density_compensation_function( >>> image = adjkb_ob(data * dcomp, omega) """ device = ktraj.device + batch_size = 1 + + if ktraj.ndim not in (2, 3): + raise ValueError("ktraj must have 2 or 3 dimensions") + + if ktraj.ndim == 3: + if ktraj.shape[0] == 1: + ktraj = ktraj[0] + else: + batch_size = ktraj.shape[0] # init nufft variables ( @@ -80,10 +91,12 @@ def calc_density_compensation_function( device=device, ) - test_sig = torch.ones([1, 1, ktraj.shape[-1]], dtype=tables[0].dtype, device=device) + test_sig = torch.ones( + [batch_size, 1, ktraj.shape[-1]], dtype=tables[0].dtype, device=device + ) for _ in range(num_iterations): new_sig = tkbnF.kb_table_interp( - tkbnF.kb_table_interp_adjoint( + image=tkbnF.kb_table_interp_adjoint( data=test_sig, omega=ktraj, tables=tables, @@ -101,7 +114,6 @@ def calc_density_compensation_function( offsets=offsets_t, ) - norm_new_sig = torch.abs(new_sig) - test_sig = test_sig / norm_new_sig + test_sig = test_sig / torch.abs(new_sig) return test_sig diff --git a/torchkbnufft/_nufft/fft.py b/torchkbnufft/_nufft/fft.py index e9b94a2..c760456 100644 --- a/torchkbnufft/_nufft/fft.py +++ b/torchkbnufft/_nufft/fft.py @@ -146,7 +146,9 @@ def fft_filter(image: Tensor, kernel: Tensor, norm: Optional[str] = "ortho") -> raise ValueError("Only option for norm is 'ortho'.") im_size = torch.tensor(image.shape[2:], dtype=torch.long, device=image.device) - grid_size = torch.tensor(kernel.shape[2:], dtype=torch.long, device=image.device) + grid_size = torch.tensor( + kernel.shape[-len(image.shape[2:]) :], dtype=torch.long, device=image.device + ) # set up n-dimensional zero pad # zero pad for oversampled nufft diff --git a/torchkbnufft/_nufft/interp.py b/torchkbnufft/_nufft/interp.py index 61bc0aa..339d85d 100644 --- a/torchkbnufft/_nufft/interp.py +++ b/torchkbnufft/_nufft/interp.py @@ -1,4 +1,3 @@ -import math from typing import List, Tuple, Union import numpy as np @@ -86,6 +85,27 @@ def spmat_interp_adjoint( return torch.view_as_complex(image).reshape(*output_size) +@torch.jit.script +def calc_split_sizes( + length: int, + num_splits: int, +) -> List[int]: + """Same as np.array_split for PyTorch.""" + # TODO: replace all calls of this function with split_tensor in PyTorch 1.8 + size1 = length // num_splits + 1 + num_size1 = length % num_splits + size2 = length // num_splits + + split_sizes: List[int] = [] + for i in range(num_splits): + if i < num_size1: + split_sizes.append(size1) + else: + split_sizes.append(size2) + + return split_sizes + + @torch.jit.script def calc_coef_and_indices( tm: Tensor, @@ -175,7 +195,7 @@ def table_interp_one_batch( centers = torch.floor(numpoints * table_oversamp / 2).to(dtype=int_type) # offset from k-space to first coef loc - base_offset = 1 + torch.floor(tm - numpoints.unsqueeze(1) / 2.0).to(dtype=int_type) + base_offset = 1 + torch.floor(tm - numpoints.unsqueeze(-1) / 2.0).to(dtype=int_type) # flatten image dimensions image = image.reshape(image.shape[0], image.shape[1], -1) @@ -199,13 +219,79 @@ def table_interp_one_batch( # phase for fftshift return kdat * imag_exp( - torch.mv(torch.transpose(omega, 1, 0), n_shift), + torch.sum(omega * n_shift.unsqueeze(-1), dim=-2, keepdim=True), return_complex=True, ) @torch.jit.script -def table_interp_over_batches( +def table_interp_multiple_batches( + image: Tensor, + omega: Tensor, + tables: List[Tensor], + n_shift: Tensor, + numpoints: Tensor, + table_oversamp: Tensor, + offsets: Tensor, +) -> Tensor: + """Table interpolation with for loop over batch dimension.""" + 
kdat = [] + for (it_image, it_omega) in zip(image, omega): + kdat.append( + table_interp_one_batch( + it_image.unsqueeze(0), + it_omega, + tables, + n_shift, + numpoints, + table_oversamp, + offsets, + ) + ) + + return torch.cat(kdat) + + +@torch.jit.script +def table_interp_fork_over_batchdim( + image: Tensor, + omega: Tensor, + tables: List[Tensor], + n_shift: Tensor, + numpoints: Tensor, + table_oversamp: Tensor, + offsets: Tensor, + num_forks: int, +) -> Tensor: + """Table interpolation with forking over k-space.""" + + # indexing is worst when we have repeated indices - let's spread them out + split_sizes = calc_split_sizes(omega.shape[0], num_forks) + + # initialize the fork processes + futures: List[torch.jit.Future[torch.Tensor]] = [] + for (image_chunk, omega_chunk) in zip( + image.split(split_sizes), omega.split(split_sizes) + ): + futures.append( + torch.jit.fork( + table_interp_multiple_batches, + image_chunk, + omega_chunk, + tables, + n_shift, + numpoints, + table_oversamp, + offsets, + ) + ) + + # collect the results + return torch.cat([torch.jit.wait(future) for future in futures]) + + +@torch.jit.script +def table_interp_fork_over_kspace( image: Tensor, omega: Tensor, tables: List[Tensor], @@ -221,6 +307,7 @@ def table_interp_over_batches( klength = omega.shape[1] omega_chunks = [omega[:, ind:klength:num_forks] for ind in range(num_forks)] + # initialize the fork processes futures: List[torch.jit.Future[torch.Tensor]] = [] for omega_chunk in omega_chunks: futures.append( @@ -244,6 +331,7 @@ def table_interp_over_batches( device=image.device, ) + # collect the results for ind, future in enumerate(futures): kdat[:, :, ind:klength:num_forks] = torch.jit.wait(future) @@ -275,33 +363,65 @@ def table_interp( table_oversamp: Size of table in each dimension. offsets: A list of offset values for interpolation. min_kspace_per_fork: Minimum number of k-space samples to use in each - process fork. + process fork. Only used for single trajectory on CPU. Returns: ``image`` interpolated to k-space locations at ``omega``. """ - if image.device == torch.device("cpu"): - # we fork processes for indexing, so we need to do a bit of thread management - # for OMP to make sure we don't oversubscribe (managment not necessary for non-OMP) - num_threads = torch.get_num_threads() - factors = torch.arange(1, math.sqrt(num_threads)) - factors = factors[torch.remainder(torch.tensor(num_threads), factors) == 0] - threads_per_fork = 1 - for factor in factors: - # minimum k-space points per fork - if num_threads / factor <= omega.shape[1] / min_kspace_per_fork: + if omega.ndim not in (2, 3): + raise ValueError("omega must have 2 or 3 dimensions.") + + if omega.ndim == 3: + if omega.shape[0] == 1: + omega = omega[0] # broadcast a single traj + + if omega.ndim == 3: + if not omega.shape[0] == image.shape[0]: + raise ValueError( + "If omega has batch dim, omega batch dimension must match image." 
+ ) + + # we fork processes for accumulation, so we need to do a bit of thread + # management for OMP to make sure we don't oversubscribe (managment not + # necessary for non-OMP) + num_threads = torch.get_num_threads() + factors = torch.arange(1, num_threads + 1) + factors = factors[torch.remainder(torch.tensor(num_threads), factors) == 0] + threads_per_fork = num_threads # default fallback + + if omega.ndim == 3: + # increase number of forks as long as it's not greater than batch size + for factor in factors.flip(0): + if num_threads // factor <= omega.shape[0]: threads_per_fork = int(factor) - break + + num_forks = num_threads // threads_per_fork + + if USING_OMP and image.device == torch.device("cpu"): + torch.set_num_threads(threads_per_fork) + kdat = table_interp_fork_over_batchdim( + image, omega, tables, n_shift, numpoints, table_oversamp, offsets, num_forks + ) + if USING_OMP and image.device == torch.device("cpu"): + torch.set_num_threads(num_threads) + elif image.device == torch.device("cpu"): + # determine number of process forks while keeping a minimum amount of + # k-space per fork + for factor in factors.flip(0): + if omega.shape[1] / (num_threads // factor) >= min_kspace_per_fork: + threads_per_fork = int(factor) + num_forks = num_threads // threads_per_fork if USING_OMP: torch.set_num_threads(threads_per_fork) - kdat = table_interp_over_batches( + kdat = table_interp_fork_over_kspace( image, omega, tables, n_shift, numpoints, table_oversamp, offsets, num_forks ) if USING_OMP: torch.set_num_threads(num_threads) else: + # no forking for batchless omega on GPU kdat = table_interp_one_batch( image, omega, tables, n_shift, numpoints, table_oversamp, offsets ) @@ -309,53 +429,194 @@ def table_interp( return kdat +@torch.jit.script def accum_tensor_index_add(image: Tensor, arr_ind: Tensor, data: Tensor) -> Tensor: """We fork this function for the adjoint accumulation.""" - return image.index_add_(0, arr_ind, data) + if arr_ind.ndim == 2: + for (image_batch, arr_ind_batch, data_batch) in zip(image, arr_ind, data): + for (image_coil, data_coil) in zip(image_batch, data_batch): + image_coil.index_add_(0, arr_ind_batch, data_coil) + else: + for (image_it, data_it) in zip(image, data): + image_it.index_add_(0, arr_ind, data_it) + + return image +@torch.jit.script def accum_tensor_index_put(image: Tensor, arr_ind: Tensor, data: Tensor) -> Tensor: """We fork this function for the adjoint accumulation.""" - return image.index_put_((arr_ind,), data, accumulate=True) + if arr_ind.ndim == 2: + for (image_batch, arr_ind_batch, data_batch) in zip(image, arr_ind, data): + for (image_coil, data_coil) in zip(image_batch, data_batch): + image_coil.index_put_((arr_ind_batch,), data_coil, accumulate=True) + else: + for (image_it, data_it) in zip(image, data): + image_it.index_put_((arr_ind,), data_it, accumulate=True) + + return image @torch.jit.script -def fork_and_accum(image: Tensor, arr_ind: Tensor, data: Tensor, num_forks: int): +def fork_and_accum( + image: Tensor, arr_ind: Tensor, data: Tensor, num_forks: int +) -> Tensor: """Process forking and per batch/coil accumulation function.""" device = image.device + # divide the work + split_sizes = calc_split_sizes(image.shape[0], num_forks) + + # initialize the fork processes futures: List[torch.jit.Future[torch.Tensor]] = [] - for batch_ind in range(image.shape[0]): - for coil_ind in range(image.shape[1]): - # if we've used all our forks, wait for one to finish and pop - if len(futures) == num_forks: - torch.jit.wait(futures[0]) - futures.pop(0) 
- - # one of these is faster on cpu, other is faster on gpu + if arr_ind.ndim == 2: + for (image_chunk, arr_ind_chunk, data_chunk) in zip( + image.split(split_sizes), + arr_ind.split(split_sizes), + data.split(split_sizes), + ): + if device == torch.device("cpu"): + futures.append( + torch.jit.fork( + accum_tensor_index_put, + image_chunk, + arr_ind_chunk, + data_chunk, + ) + ) + else: + futures.append( + torch.jit.fork( + accum_tensor_index_add, + image_chunk, + arr_ind_chunk, + data_chunk, + ) + ) + else: + for (image_chunk, data_chunk) in zip( + image.split(split_sizes), data.split(split_sizes) + ): if device == torch.device("cpu"): futures.append( torch.jit.fork( accum_tensor_index_put, - image[batch_ind, coil_ind], + image_chunk, arr_ind, - data[batch_ind, coil_ind], + data_chunk, ) ) else: futures.append( torch.jit.fork( accum_tensor_index_add, - image[batch_ind, coil_ind], + image_chunk, arr_ind, - data[batch_ind, coil_ind], + data_chunk, ) ) + + # wait for processes to finish + # results in-place _ = [torch.jit.wait(future) for future in futures] + return image + @torch.jit.script -def sort_data( +def calc_coef_and_indices_batch( + tm: Tensor, + base_offset: Tensor, + offset_increments: Tensor, + tables: List[Tensor], + centers: Tensor, + table_oversamp: Tensor, + grid_size: Tensor, + conjcoef: bool, +) -> Tuple[Tensor, Tensor]: + """For loop coef calculation over batch dim.""" + coef = [] + arr_ind = [] + for (tm_it, base_offset_it) in zip(tm, base_offset): + coef_it, arr_ind_it = calc_coef_and_indices( + tm=tm_it, + base_offset=base_offset_it, + offset_increments=offset_increments, + tables=tables, + centers=centers, + table_oversamp=table_oversamp, + grid_size=grid_size, + conjcoef=conjcoef, + ) + + coef.append(coef_it) + arr_ind.append(arr_ind_it) + + return (torch.stack(coef), torch.stack(arr_ind)) + + +@torch.jit.script +def calc_coef_and_indices_fork_over_batches( + tm: Tensor, + base_offset: Tensor, + offset_increments: Tensor, + tables: List[Tensor], + centers: Tensor, + table_oversamp: Tensor, + grid_size: Tensor, + conjcoef: bool, + num_forks: int, +) -> Tuple[Tensor, Tensor]: + """Split work across batchdim, fork processes.""" + if tm.ndim == 3: + if tm.shape[0] == 1: + tm = tm[0] + + if tm.ndim == 2: + coef, arr_ind = calc_coef_and_indices( + tm=tm, + base_offset=base_offset, + offset_increments=offset_increments, + tables=tables, + centers=centers, + table_oversamp=table_oversamp, + grid_size=grid_size, + conjcoef=conjcoef, + ) + else: + # divide the work + split_sizes = calc_split_sizes(tm.shape[0], num_forks) + + # initialize the fork processes + futures: List[torch.jit.Future[Tuple[Tensor, Tensor]]] = [] + for (tm_chunk, base_offset_chunk) in zip( + tm.split(split_sizes), + base_offset.split(split_sizes), + ): + futures.append( + torch.jit.fork( + calc_coef_and_indices_batch, + tm_chunk, + base_offset_chunk, + offset_increments, + tables, + centers, + table_oversamp, + grid_size, + conjcoef, + ) + ) + + # collect the results + results = [torch.jit.wait(future) for future in futures] + coef = torch.cat([result[0] for result in results]) + arr_ind = torch.cat([result[1] for result in results]) + + return coef, arr_ind + + +@torch.jit.script +def sort_one_batch( tm: Tensor, omega: Tensor, data: Tensor, grid_size: Tensor ) -> Tuple[Tensor, Tensor, Tensor]: """Sort input tensors by ordered values of tm.""" @@ -368,6 +629,28 @@ def sort_data( return tm[:, indices], omega[:, indices], data[:, :, indices] +@torch.jit.script +def sort_data( + tm: Tensor, omega: Tensor, 
data: Tensor, grid_size: Tensor +) -> Tuple[Tensor, Tensor, Tensor]: + """Sort input tensors by ordered values of tm.""" + if omega.ndim == 3: + # loop over batch dimension to get sorted k-space + results: List[Tuple[Tensor, Tensor, Tensor]] = [] + for (tm_it, omega_it, data_it) in zip(tm, omega, data): + results.append( + sort_one_batch(tm_it, omega_it, data_it.unsqueeze(0), grid_size) + ) + + tm_ret = torch.stack([result[0] for result in results]) + omega_ret = torch.stack([result[1] for result in results]) + data_ret = torch.cat([result[2] for result in results]) + else: + tm_ret, omega_ret, data_ret = sort_one_batch(tm, omega, data, grid_size) + + return tm_ret, omega_ret, data_ret + + def table_interp_adjoint( data: Tensor, omega: Tensor, @@ -392,26 +675,47 @@ def table_interp_adjoint( numpoints: Number of neighbors in each dimension. table_oversamp: Size of table in each dimension. offsets: A list of offset values for interpolation. - min_kspace_per_fork: Minimum number of k-space samples to use in each - process fork. + grid_size: Size of grid to interpolate to. Returns: ``data`` interpolated to gridded locations. """ + if omega.ndim not in (2, 3): + raise ValueError("omega must have 2 or 3 dimensions.") + + if omega.ndim == 3: + if omega.shape[0] == 1: + omega = omega[0] # broadcast a single traj + + if omega.ndim == 3: + if not omega.shape[0] == data.shape[0]: + raise ValueError( + "If omega has batch dim, omega batch dimension must match data." + ) + dtype = data.dtype device = data.device int_type = torch.long - # we fork processes for accumulation, so we need to do a bit of thread management - # for OMP to make sure we don't oversubscribe (managment not necessary for non-OMP) + # we fork processes for accumulation, so we need to do a bit of thread + # management for OMP to make sure we don't oversubscribe (managment not + # necessary for non-OMP) num_threads = torch.get_num_threads() - factors = torch.arange(1, math.sqrt(num_threads)).flip(0) + factors = torch.arange(1, num_threads + 1) factors = factors[torch.remainder(torch.tensor(num_threads), factors) == 0] - threads_per_fork = 1 - for factor in factors: - if factor <= num_threads / (data.shape[0] * data.shape[1]): - threads_per_fork = int(factor) - break + threads_per_fork = num_threads # default fallback + + if omega.ndim == 3: + # increase number of forks as long as it's not greater than batch size + for factor in factors.flip(0): + if num_threads // factor <= omega.shape[0]: + threads_per_fork = int(factor) + else: + # increase forks as long as it's less/eq than batch * coildim + for factor in factors.flip(0): + if num_threads // factor <= data.shape[0] * data.shape[1]: + threads_per_fork = int(factor) + num_forks = num_threads // threads_per_fork # calculate output size @@ -428,7 +732,7 @@ def table_interp_adjoint( centers = torch.floor(numpoints * table_oversamp / 2).to(dtype=int_type) # offset from k-space to first coef loc - base_offset = 1 + torch.floor(tm - numpoints.unsqueeze(1) / 2.0).to(dtype=int_type) + base_offset = 1 + torch.floor(tm - numpoints.unsqueeze(-1) / 2.0).to(dtype=int_type) # initialized flattened image image = torch.zeros( @@ -441,7 +745,7 @@ def table_interp_adjoint( data = ( data * imag_exp( - torch.mv(torch.transpose(omega, 1, 0), n_shift), + torch.sum(omega * n_shift.unsqueeze(-1), dim=-2, keepdim=True), return_complex=True, ).conj() ) @@ -453,7 +757,9 @@ def table_interp_adjoint( # loop over offsets and take advantage of broadcasting for offset in offsets: - coef, arr_ind = 
calc_coef_and_indices( + if USING_OMP and device == torch.device("cpu") and tm.ndim == 3: + torch.set_num_threads(threads_per_fork) + coef, arr_ind = calc_coef_and_indices_fork_over_batches( tm=tm, base_offset=base_offset, offset_increments=offset, @@ -462,9 +768,16 @@ def table_interp_adjoint( table_oversamp=table_oversamp, grid_size=grid_size, conjcoef=True, + num_forks=num_forks, ) + if USING_OMP and device == torch.device("cpu") and tm.ndim == 3: + torch.set_num_threads(num_threads) + + # multiply coefs to data + if coef.ndim == 2: + coef = coef.unsqueeze(1) + assert coef.ndim == data.ndim - # we have to fork this multiply ourselves tmp = coef * data if not device == torch.device("cpu"): @@ -473,7 +786,25 @@ def table_interp_adjoint( if USING_OMP and device == torch.device("cpu"): torch.set_num_threads(threads_per_fork) # this is a much faster way of doing index accumulation - fork_and_accum(image, arr_ind, tmp, num_forks) + if arr_ind.ndim == 1: + # fork over coils and batches + if device == torch.device("cpu"): + image = fork_and_accum( + image.view(data.shape[0] * data.shape[1], output_prod), + arr_ind, + tmp.view(data.shape[0] * data.shape[1], -1), + num_forks, + ).view(data.shape[0], data.shape[1], output_prod) + else: + image = fork_and_accum( + image.view(data.shape[0] * data.shape[1], output_prod, 2), + arr_ind, + tmp.view(data.shape[0] * data.shape[1], -1, 2), + num_forks, + ).view(data.shape[0], data.shape[1], output_prod, 2) + else: + # fork just over batches + image = fork_and_accum(image, arr_ind, tmp, num_forks) if USING_OMP and device == torch.device("cpu"): torch.set_num_threads(num_threads) diff --git a/torchkbnufft/_nufft/spmat.py b/torchkbnufft/_nufft/spmat.py index e6f0fa9..7e379e0 100644 --- a/torchkbnufft/_nufft/spmat.py +++ b/torchkbnufft/_nufft/spmat.py @@ -56,6 +56,8 @@ def calc_tensor_spmatrix( >>> adjkb_ob = tkbn.KbNufftAdjoint(im_size=(8, 8)) >>> image = adjkb_ob(data, omega, spmats) """ + if not omega.ndim == 2: + raise ValueError("Sparse matrix calculation not implemented for batched omega.") ( im_size, grid_size, diff --git a/torchkbnufft/_nufft/toep.py b/torchkbnufft/_nufft/toep.py index 2980387..5d195ba 100644 --- a/torchkbnufft/_nufft/toep.py +++ b/torchkbnufft/_nufft/toep.py @@ -41,8 +41,9 @@ def calc_toeplitz_kernel( This function is intended to be used in conjunction with :py:class:`~torchkbnufft.ToepNufft` for forward operations. - * :attr:`omega` should be of size ``(len(im_size), klength)``, - where ``klength`` is the length of the k-space trajectory. + * :attr:`omega` should be of size ``(len(im_size), klength)`` or + ``(N, len(im_size), klength)``, where ``klength`` is the length of the + k-space trajectory. Args: omega: k-space trajectory (in radians/voxel). 
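# A minimal sketch of the batched omega path documented above, mirroring
# the batched Toeplitz test added in this patch; shapes are illustrative.
import numpy as np
import torch
import torchkbnufft as tkbn

batch_size, ncoil, klength = 3, 4, 150
im_size = (32, 32)

image = torch.randn(batch_size, 1, *im_size, dtype=torch.complex64)
smaps = torch.randn(batch_size, ncoil, *im_size, dtype=torch.complex64)
omega = torch.rand(batch_size, len(im_size), klength) * 2 * np.pi - np.pi

# a 3D omega of shape (N, len(im_size), klength) yields one Toeplitz
# kernel per batch element
kernel = tkbn.calc_toeplitz_kernel(omega, im_size=im_size, norm="ortho")

toep_ob = tkbn.ToepNufft()
out = toep_ob(image, kernel, smaps=smaps, norm="ortho")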
@@ -71,6 +72,70 @@ def calc_toeplitz_kernel( >>> kernel = tkbn.calc_toeplitz_kernel(omega, im_size=(8, 8)) >>> image = toep_ob(image, kernel) """ + if omega.ndim not in (2, 3): + raise ValueError("Unrecognized k-space shape.") + + if weights is not None: + if weights.ndim not in (2, 3): + raise ValueError("Unrecognized weights dimension.") + if omega.ndim == 3 and weights.ndim == 2: + if weights.shape[0] == 1: + weights = weights.repeat(omega.shape[0], 1) + if not weights.shape[0] == omega.shape[0]: + raise ValueError("weights and omega do not have same batch size") + + if omega.ndim == 2: + kernel = calc_one_batch_toeplitz_kernel( + omega=omega, + im_size=im_size, + weights=weights, + norm=norm, + grid_size=grid_size, + numpoints=numpoints, + n_shift=n_shift, + table_oversamp=table_oversamp, + kbwidth=kbwidth, + order=order, + ) + else: + kernel_list = [] + for i, omega_it in enumerate(omega): + if weights is None: + weights_it = None + else: + weights_it = weights[i] + kernel_list.append( + calc_one_batch_toeplitz_kernel( + omega=omega_it, + im_size=im_size, + weights=weights_it, + norm=norm, + grid_size=grid_size, + numpoints=numpoints, + n_shift=n_shift, + table_oversamp=table_oversamp, + kbwidth=kbwidth, + order=order, + ) + ) + kernel = torch.stack(kernel_list) + + return kernel + + +def calc_one_batch_toeplitz_kernel( + omega: Tensor, + im_size: Sequence[int], + weights: Optional[Tensor] = None, + norm: Optional[str] = None, + grid_size: Optional[Sequence[int]] = None, + numpoints: Union[int, Sequence[int]] = 6, + n_shift: Optional[Sequence[int]] = None, + table_oversamp: Union[int, Sequence[int]] = 2 ** 10, + kbwidth: float = 2.34, + order: Union[float, Sequence[float]] = 0.0, +) -> Tensor: + """See calc_toeplitz_kernel().""" device = omega.device normalized = True if norm == "ortho" else False @@ -108,7 +173,7 @@ def calc_toeplitz_kernel( kernel = hermitify(kernel, 2) # put the kernel in fft space - return fft_fn(kernel, omega.shape[0], normalized=normalized) + return fft_fn(kernel, omega.shape[0], normalized=normalized)[0, 0] def adjoint_flip_and_concat( diff --git a/torchkbnufft/functional/interp.py b/torchkbnufft/functional/interp.py index 6fb9886..ce003d6 100644 --- a/torchkbnufft/functional/interp.py +++ b/torchkbnufft/functional/interp.py @@ -1,5 +1,6 @@ from typing import List, Tuple +import torch from torch import Tensor from .._autograd.interp import ( @@ -27,7 +28,20 @@ def kb_spmat_interp(image: Tensor, interp_mats: Tuple[Tensor, Tensor]) -> Tensor Returns: ``image`` calculated at scattered locations. """ - return KbSpmatInterpForward.apply(image, interp_mats) + is_complex = True + if not image.is_complex(): + if not image.shape[-1] == 2: + raise ValueError("For real inputs, last dimension must be size 2.") + + is_complex = False + image = torch.view_as_complex(image) + + data = KbSpmatInterpForward.apply(image, interp_mats) + + if is_complex is False: + data = torch.view_as_real(data) + + return data def kb_spmat_interp_adjoint( @@ -49,7 +63,20 @@ def kb_spmat_interp_adjoint( Returns: ``data`` calculated at gridded locations. 
""" - return KbSpmatInterpAdjoint.apply(data, interp_mats, grid_size) + is_complex = True + if not data.is_complex(): + if not data.shape[-1] == 2: + raise ValueError("For real inputs, last dimension must be size 2.") + + is_complex = False + data = torch.view_as_complex(data) + + image = KbSpmatInterpAdjoint.apply(data, interp_mats, grid_size) + + if is_complex is False: + image = torch.view_as_real(image) + + return image def kb_table_interp( @@ -79,10 +106,23 @@ def kb_table_interp( Returns: ``image`` calculated at scattered locations. """ - return KbTableInterpForward.apply( + is_complex = True + if not image.is_complex(): + if not image.shape[-1] == 2: + raise ValueError("For real inputs, last dimension must be size 2.") + + is_complex = False + image = torch.view_as_complex(image) + + data = KbTableInterpForward.apply( image, omega, tables, n_shift, numpoints, table_oversamp, offsets ) + if is_complex is False: + data = torch.view_as_real(data) + + return data + def kb_table_interp_adjoint( data: Tensor, @@ -114,6 +154,19 @@ def kb_table_interp_adjoint( Returns: ``data`` calculated at gridded locations. """ - return KbTableInterpAdjoint.apply( + is_complex = True + if not data.is_complex(): + if not data.shape[-1] == 2: + raise ValueError("For real inputs, last dimension must be size 2.") + + is_complex = False + data = torch.view_as_complex(data) + + image = KbTableInterpAdjoint.apply( data, omega, tables, n_shift, numpoints, table_oversamp, offsets, grid_size ) + + if is_complex is False: + image = torch.view_as_real(image) + + return image diff --git a/torchkbnufft/functional/nufft.py b/torchkbnufft/functional/nufft.py index 9e6b3b2..297f275 100644 --- a/torchkbnufft/functional/nufft.py +++ b/torchkbnufft/functional/nufft.py @@ -1,5 +1,6 @@ from typing import List, Optional, Tuple +import torch from torch import Tensor from .._nufft.fft import fft_and_scale, ifft_and_scale @@ -42,19 +43,30 @@ def kb_spmat_nufft( Returns: ``image`` calculated at scattered Fourier locations. """ - image = fft_and_scale( - image=image, - scaling_coef=scaling_coef, - im_size=im_size, - grid_size=grid_size, - norm=norm, - ) + is_complex = True + if not image.is_complex(): + if not image.shape[-1] == 2: + raise ValueError("For real inputs, last dimension must be size 2.") + + is_complex = False + image = torch.view_as_complex(image) - return kb_spmat_interp( - image=image, + data = kb_spmat_interp( + image=fft_and_scale( + image=image, + scaling_coef=scaling_coef, + im_size=im_size, + grid_size=grid_size, + norm=norm, + ), interp_mats=interp_mats, ) + if is_complex is False: + data = torch.view_as_real(data) + + return data + def kb_spmat_nufft_adjoint( data: Tensor, @@ -87,18 +99,29 @@ def kb_spmat_nufft_adjoint( Returns: ``data`` transformed to an image. 
""" - data = kb_spmat_interp_adjoint( - data=data, interp_mats=interp_mats, grid_size=grid_size - ) + is_complex = True + if not data.is_complex(): + if not data.shape[-1] == 2: + raise ValueError("For real inputs, last dimension must be size 2.") + + is_complex = False + data = torch.view_as_complex(data) - return ifft_and_scale( - image=data, + image = ifft_and_scale( + image=kb_spmat_interp_adjoint( + data=data, interp_mats=interp_mats, grid_size=grid_size + ), scaling_coef=scaling_coef, im_size=im_size, grid_size=grid_size, norm=norm, ) + if is_complex is False: + image = torch.view_as_real(image) + + return image + def kb_table_nufft( image: Tensor, @@ -138,16 +161,22 @@ def kb_table_nufft( Returns: ``image`` calculated at scattered Fourier locations. """ - image = fft_and_scale( - image=image, - scaling_coef=scaling_coef, - im_size=im_size, - grid_size=grid_size, - norm=norm, - ) + is_complex = True + if not image.is_complex(): + if not image.shape[-1] == 2: + raise ValueError("For real inputs, last dimension must be size 2.") + + is_complex = False + image = torch.view_as_complex(image) - return kb_table_interp( - image=image, + data = kb_table_interp( + image=fft_and_scale( + image=image, + scaling_coef=scaling_coef, + im_size=im_size, + grid_size=grid_size, + norm=norm, + ), omega=omega, tables=tables, n_shift=n_shift, @@ -156,6 +185,11 @@ def kb_table_nufft( offsets=offsets, ) + if is_complex is False: + data = torch.view_as_real(data) + + return data + def kb_table_nufft_adjoint( data: Tensor, @@ -195,21 +229,32 @@ def kb_table_nufft_adjoint( Returns: ``data`` transformed to an image. """ - data = kb_table_interp_adjoint( - data=data, - omega=omega, - tables=tables, - n_shift=n_shift, - numpoints=numpoints, - table_oversamp=table_oversamp, - offsets=offsets, - grid_size=grid_size, - ) + is_complex = True + if not data.is_complex(): + if not data.shape[-1] == 2: + raise ValueError("For real inputs, last dimension must be size 2.") - return ifft_and_scale( - image=data, + is_complex = False + data = torch.view_as_complex(data) + + image = ifft_and_scale( + image=kb_table_interp_adjoint( + data=data, + omega=omega, + tables=tables, + n_shift=n_shift, + numpoints=numpoints, + table_oversamp=table_oversamp, + offsets=offsets, + grid_size=grid_size, + ), scaling_coef=scaling_coef, im_size=im_size, grid_size=grid_size, norm=norm, ) + + if is_complex is False: + image = torch.view_as_real(image) + + return image diff --git a/torchkbnufft/modules/kbinterp.py b/torchkbnufft/modules/kbinterp.py index dcfc268..3965922 100644 --- a/torchkbnufft/modules/kbinterp.py +++ b/torchkbnufft/modules/kbinterp.py @@ -111,8 +111,15 @@ def forward( Input tensors should be of shape ``(N, C) + grid_size``, where ``N`` is the batch size and ``C`` is the number of sensitivity coils. ``omega``, the k-space trajectory, should be of size - ``(len(grid_size), klength)``, where ``klength`` is the length of the - k-space trajectory. + ``(len(grid_size), klength)`` or ``(N, len(grid_size), klength)``, + where ``klength`` is the length of the k-space trajectory. + + Note: + + If the batch dimension is included in ``omega``, the interpolator + will parallelize over the batch dimension. This is efficient for + many small trajectories that might occur in dynamic imaging + settings. If your tensors are real-valued, ensure that 2 is the size of the last dimension. @@ -127,14 +134,6 @@ def forward( Returns: ``image`` calculated at Fourier frequencies specified by ``omega``. 
""" - is_complex = True - if not image.is_complex(): - if not image.shape[-1] == 2: - raise ValueError("For real inputs, last dimension must be size 2.") - - is_complex = False - image = torch.view_as_complex(image) - if interp_mats is not None: output = tkbnF.kb_spmat_interp(image=image, interp_mats=interp_mats) else: @@ -157,9 +156,6 @@ def forward( offsets=self.offsets.to(torch.long), ) - if not is_complex: - output = torch.view_as_real(output) - return output @@ -243,9 +239,17 @@ def forward( Input tensors should be of shape ``(N, C) + klength``, where ``N`` is the batch size and ``C`` is the number of sensitivity coils. ``omega``, - the k-space trajectory, should be of size ``(len(im_size), klength)``, + the k-space trajectory, should be of size + ``(len(grid_size), klength)`` or ``(N, len(grid_size), klength)``, where ``klength`` is the length of the k-space trajectory. + Note: + + If the batch dimension is included in ``omega``, the interpolator + will parallelize over the batch dimension. This is efficient for + many small trajectories that might occur in dynamic imaging + settings. + If your tensors are real-valued, ensure that 2 is the size of the last dimension. @@ -259,14 +263,6 @@ def forward( Returns: ``data`` interpolated to the grid. """ - is_complex = True - if not data.is_complex(): - if not data.shape[-1] == 2: - raise ValueError("For real inputs, last dimension must be size 2.") - - is_complex = False - data = torch.view_as_complex(data) - if grid_size is None: assert isinstance(self.grid_size, Tensor) grid_size = self.grid_size @@ -295,7 +291,4 @@ def forward( grid_size=grid_size, ) - if not is_complex: - output = torch.view_as_real(output) - return output diff --git a/torchkbnufft/modules/kbnufft.py b/torchkbnufft/modules/kbnufft.py index 229c704..541ef4f 100644 --- a/torchkbnufft/modules/kbnufft.py +++ b/torchkbnufft/modules/kbnufft.py @@ -134,8 +134,16 @@ def forward( Input tensors should be of shape ``(N, C) + im_size``, where ``N`` is the batch size and ``C`` is the number of sensitivity coils. ``omega``, - the k-space trajectory, should be of size ``(len(im_size), klength)``, - where ``klength`` is the length of the k-space trajectory. + the k-space trajectory, should be of size ``(len(grid_size), klength)`` + or ``(N, len(grid_size), klength)``, where ``klength`` is the length of + the k-space trajectory. + + Note: + + If the batch dimension is included in ``omega``, the interpolator + will parallelize over the batch dimension. This is efficient for + many small trajectories that might occur in dynamic imaging + settings. If your tensors are real, ensure that 2 is the size of the last dimension. @@ -308,8 +316,16 @@ def forward( Input tensors should be of shape ``(N, C) + klength``, where ``N`` is the batch size and ``C`` is the number of sensitivity coils. ``omega``, - the k-space trajectory, should be of size ``(len(im_size), klength)``, - where ``klength`` is the length of the k-space trajectory. + the k-space trajectory, should be of size ``(len(grid_size), klength)`` + or ``(N, len(grid_size), klength)``, where ``klength`` is the length of + the k-space trajectory. + + Note: + + If the batch dimension is included in ``omega``, the interpolator + will parallelize over the batch dimension. This is efficient for + many small trajectories that might occur in dynamic imaging + settings. If your tensors are real, ensure that 2 is the size of the last dimension. 
@@ -426,15 +442,31 @@ def toep_batch_loop( self, image: Tensor, smaps: Tensor, kernel: Tensor, norm: Optional[str] ) -> Tensor: output = [] - for (mini_image, smap) in zip(image, smaps): - mini_image = mini_image.unsqueeze(0) * smap.unsqueeze(0) - mini_image = tkbnF.fft_filter(image=mini_image, kernel=kernel, norm=norm) - mini_image = torch.sum( - mini_image * smap.unsqueeze(0).conj(), - dim=1, - keepdim=True, - ) - output.append(mini_image.squeeze(0)) + if len(kernel.shape) > len(image.shape[2:]): + # run with batching for kernel + for (mini_image, smap, mini_kernel) in zip(image, smaps, kernel): + mini_image = mini_image.unsqueeze(0) * smap.unsqueeze(0) + mini_image = tkbnF.fft_filter( + image=mini_image, kernel=mini_kernel, norm=norm + ) + mini_image = torch.sum( + mini_image * smap.unsqueeze(0).conj(), + dim=1, + keepdim=True, + ) + output.append(mini_image.squeeze(0)) + else: + for (mini_image, smap) in zip(image, smaps): + mini_image = mini_image.unsqueeze(0) * smap.unsqueeze(0) + mini_image = tkbnF.fft_filter( + image=mini_image, kernel=kernel, norm=norm + ) + mini_image = torch.sum( + mini_image * smap.unsqueeze(0).conj(), + dim=1, + keepdim=True, + ) + output.append(mini_image.squeeze(0)) return torch.stack(output) @@ -457,6 +489,9 @@ def forward( Returns: ``image`` after applying the Toeplitz forward/backward NUFFT. """ + if not kernel.dtype == image.dtype: + raise TypeError("kernel and image must have same dtype.") + if smaps is not None: if not smaps.dtype == image.dtype: raise TypeError("image dtype does not match smaps dtype.") @@ -465,6 +500,8 @@ def forward( if not image.is_complex(): if not image.shape[-1] == 2: raise ValueError("For real inputs, last dimension must be size 2.") + if not kernel.shape[-1] == 2: + raise ValueError("For real inputs, last dimension must be size 2.") if smaps is not None: if not smaps.shape[-1] == 2: raise ValueError("For real inputs, last dimension must be size 2.") @@ -473,6 +510,16 @@ def forward( is_complex = False image = torch.view_as_complex(image) + kernel = torch.view_as_complex(kernel) + + if len(kernel.shape) > len(image.shape[2:]): + if kernel.shape[0] == 1: + kernel = kernel[0] + elif not kernel.shape[0] == image.shape[0]: + raise ValueError( + "If using batch dimension, " + "kernel must have same batch size as image" + ) if smaps is None: output = tkbnF.fft_filter(image=image, kernel=kernel, norm=norm)
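# A minimal sketch of the kernel-batching rules enforced above; shapes
# are illustrative. A kernel with batch size 1 broadcasts across the
# image batch, and kernel and image dtypes must match.
import numpy as np
import torch
import torchkbnufft as tkbn

batch_size, klength = 2, 300
im_size = (32, 32)

image = torch.randn(batch_size, 1, *im_size, dtype=torch.complex64)

# a single shared trajectory -> a batch-1 kernel, broadcast over images
omega = torch.rand(1, len(im_size), klength) * 2 * np.pi - np.pi
kernel = tkbn.calc_toeplitz_kernel(omega, im_size=im_size)

toep_ob = tkbn.ToepNufft()
out = toep_ob(image, kernel.to(image.dtype))  # dtypes must agree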