I'm trying to perform on-the-fly data undersampling in my PyTorch dataset. To do this, I perform a Toeplitz NUFFT in the __getitem__ function of my Dataset class. This works as expected. Now, I want to do batching, so I wrap the PyTorch Dataset in a PyTorch DataLoader. This also works as expected when num_workers=0. However, when num_workers is non-zero, computation of the NUFFT seemingly enters an infinite loop.
Expected behaviour
Performing a NUFFT in parallel using multiple workers should result in undersampled images.
Observed behaviour
Sampling the dataloader results in a hanging script, seemingly entering an infinite loop.
Extra information
A minimally-working example of this behaviour is attached below.
This behaviour is observed with Torch-kbnufft version 1.3.0 and PyTorch version 1.12.
CUDA is not used in this example, but it also happens when the NUFFT is computed on the GPU.
It is not limited to a Toeplitz NUFFT but also happens with the table NUFFT.
Density compensation has no influence on the observed behaviour.
Minimal example
import torch
import numpy as np
import torchkbnufft as tkbn
from skimage.data import shepp_logan_phantom
from skimage.transform import rescale
from torch.utils.data import Dataset, DataLoader

NUM_WORKERS = 0  # set to > 0 to hang


class DataUndersampler(Dataset):
    def __init__(self, undersampling_factor, spatial_scaling_factor):
        self.image = shepp_logan_phantom()
        self.image_shape = self.image.shape
        self.undersampling_factor = undersampling_factor
        self.spatial_scaling_factor = spatial_scaling_factor
        # Need the original size to determine the undersampling factor and NUFFT operators
        self.orig_size = self.image_shape
        self.image_shape = (self.image_shape[0] // self.spatial_scaling_factor,
                            self.image_shape[1] // self.spatial_scaling_factor)
        # Create an oversampled grid
        spokelength = self.image_shape[0] * 2
        self.grid_size = (spokelength, spokelength)
        # Generate a lot of spokes; a starting spoke is picked at random in __getitem__
        nspokes = 2000
        # Sample enough spokes to achieve the requested undersampling factor
        self.spokes_to_sample = int((self.orig_size[0] * np.pi / 2) / self.undersampling_factor)
        # Generate a golden-angle radial trajectory
        ga = np.deg2rad(180 / ((1 + np.sqrt(5)) / 2))
        kx = np.zeros(shape=(spokelength, nspokes))
        ky = np.zeros(shape=(spokelength, nspokes))
        ky[:, 0] = np.linspace(-np.pi, np.pi, spokelength)
        for i in range(1, nspokes):
            kx[:, i] = np.cos(ga) * kx[:, i - 1] - np.sin(ga) * ky[:, i - 1]
            ky[:, i] = np.sin(ga) * kx[:, i - 1] + np.cos(ga) * ky[:, i - 1]
        self.ky = np.transpose(ky)
        self.kx = np.transpose(kx)
        # 1D Ram-Lak filter, needed for density compensation. Density is a linear function
        # of the distance to the centre of k-space...
        ram_lak = np.abs(np.linspace(-1, 1, spokelength + 1))
        ram_lak = ram_lak[:-1]
        # ... except for the centre, which we know is sampled exactly as many times
        # as the number of spokes
        middle_idx = len(ram_lak) // 2
        ram_lak[middle_idx] = 1 / (2 * self.spokes_to_sample)
        self.ram_lak = ram_lak

    def __len__(self):
        return 1

    def __getitem__(self, index):
        if self.spatial_scaling_factor > 1:
            img = rescale(self.image, 1 / self.spatial_scaling_factor).astype(np.complex128)
        else:
            img = self.image.astype(np.complex128)
        if self.undersampling_factor > 1:
            img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0)
            toep_ob = tkbn.ToepNufft()
            # Pick a random starting spoke...
            offset = np.random.choice(range(self.ky.shape[0] - self.spokes_to_sample))
            # ... and select as many subsequent spokes as needed for the desired undersampling factor
            # todo: use continuing trajectories for cine?
            selected_ky = self.ky[offset:offset + self.spokes_to_sample].flatten()
            selected_kx = self.kx[offset:offset + self.spokes_to_sample].flatten()
            ktraj = torch.tensor(np.stack((selected_ky, selected_kx), axis=0))
            # Repeat and reshape the Ram-Lak filter so every spoke is density-compensated
            ram_lak_t = torch.from_numpy(np.tile(self.ram_lak, self.spokes_to_sample)).unsqueeze(0).unsqueeze(0)
            # Calculate the efficient Toeplitz kernel (with density compensation)
            dcomp_kernel = tkbn.calc_toeplitz_kernel(ktraj, self.image_shape, weights=ram_lak_t,
                                                     norm='ortho', numpoints=(3, 3))
            # In a single step, go to radial k-space and back to image space
            img = toep_ob(img_tensor, dcomp_kernel, norm='ortho').abs().squeeze().numpy()
            # Renormalize the output, because undersampling can change the scaling
            img /= np.max(img)
        return np.abs(img)


# Create dataset
dset = DataUndersampler(undersampling_factor=2, spatial_scaling_factor=1)

# This works fine
undersampled_img = dset[0]

# From here, the observed behaviour emerges
dloader = DataLoader(dset, shuffle=False, batch_size=1, num_workers=NUM_WORKERS)

# This statement hangs when num_workers > 0
undersampled_img = next(iter(dloader))
Hi @mmuckley. I was also thinking that as a workaround I could compute the NUFFT for a single batch outside the dataloader. In general, every sample has a different trajectory but the same number of spokes. Would this be possible?
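To make the question concrete, here is a rough sketch of what I have in mind (hypothetical code, assuming __getitem__ is changed to return the non-undersampled image together with its randomly-offset trajectory, so the expensive Toeplitz NUFFT moves out of the worker processes):

```python
# Hypothetical workaround sketch, reusing tkbn / dloader from the minimal example
# above: the DataLoader collates (image, trajectory) pairs, and the Toeplitz NUFFT
# is applied per sample in the main process after batching.
toep_ob = tkbn.ToepNufft()

img_batch, ktraj_batch = next(iter(dloader))  # shapes (N, H, W) and (N, 2, klength)
out = []
for img, ktraj in zip(img_batch, ktraj_batch):
    # every sample has its own trajectory, but the same number of spokes
    kernel = tkbn.calc_toeplitz_kernel(ktraj, img.shape, norm='ortho')
    out.append(toep_ob(img[None, None].to(torch.complex128), kernel, norm='ortho'))
undersampled = torch.cat(out).abs()
```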
Hello @maartenterpstra, it may be more efficient to loop over the list or use a batched NUFFT. The batched NUFFT is good for a large number of small NUFFTs. You can see how to use it here.
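Roughly, the batched usage looks like the sketch below (toy data; the batched-trajectory shape with a leading batch dimension is my reading of the API and should be checked against the linked documentation):

```python
# Rough sketch of a batched NUFFT: one trajectory per batch element, applied in a
# single forward/adjoint call instead of many small per-sample calls.
import numpy as np
import torch
import torchkbnufft as tkbn

im_size = (200, 200)
batch_size, klength = 4, 8000

nufft_ob = tkbn.KbNufft(im_size=im_size)
adjnufft_ob = tkbn.KbNufftAdjoint(im_size=im_size)

images = torch.randn(batch_size, 1, *im_size, dtype=torch.complex64)
# a different trajectory for each batch element, scaled to [-pi, pi)
ktraj = (torch.rand(batch_size, 2, klength) - 0.5) * 2 * np.pi

kdata = nufft_ob(images, ktraj)      # batched forward NUFFT
recon = adjnufft_ob(kdata, ktraj)    # batched adjoint NUFFT back to image space
```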
I also opened #74 as a potential enhancement with a pointer to where the code controls threading if you'd be interested in that route.
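On the threading point, one thing that may be worth trying (an assumption on my part, not a confirmed diagnosis of this issue): if the hang comes from intra-op thread oversubscription inside the worker processes, limiting each DataLoader worker to a single thread sometimes helps.

```python
# Hedged workaround sketch: pin each DataLoader worker to one PyTorch intra-op
# thread. This assumes the hang is caused by thread oversubscription, which has
# not been confirmed for this issue.
def limit_worker_threads(worker_id):
    torch.set_num_threads(1)

dloader = DataLoader(dset, shuffle=False, batch_size=1,
                     num_workers=NUM_WORKERS, worker_init_fn=limit_worker_threads)
undersampled_img = next(iter(dloader))
```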