
KB-NUFFT hangs when used in PyTorch DataLoader with num_workers > 0 #73

Open
maartenterpstra opened this issue Oct 25, 2022 · 3 comments
Labels
bug Something isn't working

Comments

@maartenterpstra

maartenterpstra commented Oct 25, 2022

Hi,

I'm trying to perform on-the-fly data undersampling in my PyTorch dataset. To do this, I perform a Toeplitz NUFFT in the __getitem__ function of my Dataset class. This works as expected. Now, I want to do batching, so I wrap the PyTorch Dataset in a PyTorch DataLoader. This works as expected when num_workers=0. However, when num_workers is non-zero, computation of the NUFFT seemingly enters an infinite loop.

Expected behaviour

Performing a NUFFT in parallel using multiple workers should result in undersampled images.

Observed behaviour

Sampling the dataloader results in a hanging script, seemingly entering an infinite loop.

Extra information

  • A minimal working example of this behaviour is attached below.
  • This behaviour is observed with Torch-kbnufft version 1.3.0 and PyTorch version 1.12.
  • CUDA is not used in this example, but the hang also happens when the NUFFT is computed on the GPU.
  • It is not limited to the Toeplitz NUFFT but also happens with the table NUFFT.
  • Density compensation has no influence on the observed behaviour.

Minimal example

import torch
from skimage.data import shepp_logan_phantom
from skimage.transform import rescale
import numpy as np
import torchkbnufft as tkbn
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader


NUM_WORKERS=0 # set to > 0 to hang

class DataUndersampler(Dataset):
    def __init__(self, undersampling_factor, spatial_scaling_factor):
        self.image = shepp_logan_phantom()
        self.image_shape = self.image.shape
        self.undersampling_factor = undersampling_factor
        self.spatial_scaling_factor = spatial_scaling_factor
        # Need the original size to determine the undersampling factor and NUFFT operators
        self.orig_size = self.image_shape
        self.image_shape = (self.image_shape[0] // self.spatial_scaling_factor, self.image_shape[1] // self.spatial_scaling_factor)
        # Create an oversampled grid
        spokelength = self.image_shape[0] * 2
        self.grid_size = (spokelength, spokelength)
        
        # Generate A LOT of spokes, pick a starting point at random
        nspokes = 2000
        
        # Sample enough spokes to achieve undersampling factor
        self.spokes_to_sample = int((self.orig_size[0] * np.pi / 2) / self.undersampling_factor)
        # Generate a golden angle radial trajectory
        ga = np.deg2rad(180 / ((1 + np.sqrt(5)) / 2))
        kx = np.zeros(shape=(spokelength, nspokes))
        ky = np.zeros(shape=(spokelength, nspokes))
        ky[:, 0] = np.linspace(-np.pi, np.pi, spokelength)
        for i in range(1, nspokes):
            kx[:, i] = np.cos(ga) * kx[:, i - 1] - np.sin(ga) * ky[:, i - 1]
            ky[:, i] = np.sin(ga) * kx[:, i - 1] + np.cos(ga) * ky[:, i - 1]
            
        self.ky = np.transpose(ky)
        self.kx = np.transpose(kx)
        # 1D Ram-Lak filter, needed for density compensation. Density is a linear
        # function of the distance from the center of k-space...
        ram_lak = np.abs(np.linspace(-1, 1, spokelength + 1))
        ram_lak = ram_lak[:-1]
        
        #... except for the center, we know exactly how often we sample that,
        # namely as many times as the number of spokes
        middle_idx = len(ram_lak) // 2
        ram_lak[middle_idx] = 1/(2 * self.spokes_to_sample)
        self.ram_lak = ram_lak

    def __len__(self):
        return 1

    def __getitem__(self, index):
        if self.spatial_scaling_factor > 1:
            img = rescale(self.image, 1/self.spatial_scaling_factor).astype(np.complex128)
        else:
            img = self.image.astype(np.complex128)
            
        if self.undersampling_factor > 1:
            img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)

            toep_ob = tkbn.ToepNufft()
            # We pick a random starting spoke
            offset = np.random.choice(range(self.ky.shape[0]-self.spokes_to_sample))

            # And select as many subsequent spokes to reach the desired undersampling factor
            # todo: use continuing trajectories for cine?
            selected_ky = self.ky[offset:offset+self.spokes_to_sample].flatten()
            selected_kx = self.kx[offset:offset+self.spokes_to_sample].flatten()

            ktraj = torch.tensor(np.stack((selected_ky, selected_kx), axis=0))
            # Repeat and reshape the ram-lak so every spoke is density-compensated
            ram_lak_t = torch.from_numpy(np.tile(self.ram_lak, self.spokes_to_sample)).unsqueeze(0).unsqueeze(0)
            # Calculate the really efficient Toeplitz kernel to compute the NUFFT
            dcomp_kernel = tkbn.calc_toeplitz_kernel(ktraj, self.image_shape, weights=ram_lak_t, norm='ortho', numpoints=(3, 3))  # with density compensation
            # And in a single step, compute the radial k-space and back to image space
            img = toep_ob(img_tensor, dcomp_kernel, norm='ortho').abs().squeeze().numpy()
            # renormalize the output, because undersampling can change the image scale.
            img /= np.max(img)
            
        return np.abs(img)

# Create dataset
dset = DataUndersampler(
            undersampling_factor=2, 
            spatial_scaling_factor=1)


# this works fine
undersampled_img = dset[0]

# From here, the observed behaviour emerges. 

dloader = DataLoader(dset,
            shuffle=False,
            batch_size=1, num_workers=NUM_WORKERS)

# this statement hangs when num_workers > 0
undersampled_img = next(iter(dloader))

@mmuckley
Owner

Hello @maartenterpstra, I think the issue is due to the table-based NUFFT, which is used inside tkbn.calc_toeplitz_kernel.

Does every sample have a different trajectory? If they're all the same, you could apply NUFFT outside the dataloader.
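A rough, untested sketch of that route, assuming the trajectory and Ram-Lak weights are shared across samples: ktraj, ram_lak_t, image_shape and dloader below refer to the objects built in the minimal example above, and the dataset is assumed to be changed so that __getitem__ returns the complex image without calling the NUFFT.

import torch
import torchkbnufft as tkbn

# Build the Toeplitz kernel once in the main process
# (ktraj, ram_lak_t and image_shape as in the minimal example above).
toep_ob = tkbn.ToepNufft()
kernel = tkbn.calc_toeplitz_kernel(ktraj, image_shape, weights=ram_lak_t,
                                   norm='ortho', numpoints=(3, 3))

# Workers only load and return complex images; the NUFFT runs in the main process.
for batch in dloader:  # batch shape (B, 1, H, W), complex dtype
    undersampled = toep_ob(batch, kernel, norm='ortho').abs()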

@maartenterpstra
Author

Hi @mmuckley. I was also thinking that as a workaround I could compute the NUFFT for a single batch outside the dataloader. In general, every sample has a different trajectory but the same number of spokes. Would this be possible?

@mmuckley added the bug label Oct 25, 2022
@mmuckley
Owner

Hello @maartenterpstra, it may be more efficient to loop over the list or use a batched NUFFT. The batched NUFFT is good for a large number of small NUFFTs. You can see how to use it here.
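A rough sketch of the batched route (shapes follow the batched-NUFFT usage described in the torchkbnufft docs; the sizes are placeholders and this is untested against the hang above): stack one trajectory per sample along a leading batch dimension and call the NUFFT once per batch.

import math
import torch
import torchkbnufft as tkbn

im_size = (400, 400)                       # placeholder image size
batch, klength = 4, 80000                  # e.g. spokelength * spokes_to_sample

nufft_ob = tkbn.KbNufft(im_size=im_size)
adjnufft_ob = tkbn.KbNufftAdjoint(im_size=im_size)

images = torch.randn(batch, 1, *im_size, dtype=torch.complex64)
# one trajectory per sample, stacked along the batch dimension
ktraj = (torch.rand(batch, 2, klength) - 0.5) * 2 * math.pi

kdata = nufft_ob(images, ktraj)            # (batch, 1, klength)
recon = adjnufft_ob(kdata, ktraj)          # back to (batch, 1, 400, 400)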

I also opened #74 as a potential enhancement with a pointer to where the code controls threading if you'd be interested in that route.
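In the meantime, a possible but unverified workaround along the lines of the threading pointer in #74: limiting intra-op threads inside each worker is a common mitigation for hangs when multi-threaded ops run in forked DataLoader workers, though it has not been confirmed against this specific issue. dset refers to the dataset from the minimal example above.

import torch
from torch.utils.data import DataLoader

def single_threaded_worker(worker_id):
    # keep each forked worker single-threaded
    torch.set_num_threads(1)

dloader = DataLoader(dset, batch_size=1, num_workers=2,
                     worker_init_fn=single_threaded_worker)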
