Skip to content

Commit 78d45a7

Browse files
authored
FIX: fix device assignment in phase unwrap for multi-processing (#49)
1 parent 63a89e8 commit 78d45a7

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/ptychi/image_proc.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,8 @@ def fourier_gradient(image: Tensor) -> Tuple[Tensor, Tensor]:
803803
The y and x gradients.
804804
"""
805805
u, v = torch.fft.fftfreq(image.shape[-2]), torch.fft.fftfreq(image.shape[-1])
806+
u = u.to(image.device)
807+
v = v.to(image.device)
806808
u, v = torch.meshgrid(u, v, indexing="ij")
807809
grad_y = torch.fft.ifft(torch.fft.fft(image, dim=-2) * (2j * torch.pi) * u, dim=-2)
808810
grad_x = torch.fft.ifft(torch.fft.fft(image, dim=-1) * (2j * torch.pi) * v, dim=-1)
@@ -950,6 +952,8 @@ def integrate_image_2d_fourier(grad_y: Tensor, grad_x: Tensor) -> Tensor:
950952
shape = grad_y.shape
951953
f = pmath.fft2_precise(grad_x + 1j * grad_y)
952954
y, x = torch.fft.fftfreq(shape[0]), torch.fft.fftfreq(shape[1])
955+
y = y.to(grad_y.device)
956+
x = x.to(grad_y.device)
953957

954958
# In PtychoShelves' get_img_int_2D.m, they set the numerator of r to be
955959
# exp(2j * pi * (x + y[:, None])) to shift it by 1 pixel. We should NOT
@@ -1413,12 +1417,14 @@ def generate_vignette_mask(
14131417
shape: tuple[int, int],
14141418
margin: int = 20,
14151419
sigma: float = 1.0,
1416-
method: Literal["gaussian", "linear"] = "gaussian"
1420+
method: Literal["gaussian", "linear"] = "gaussian",
1421+
device: Optional[torch.device] = None,
14171422
):
14181423
"""
14191424
Generate a vignette mask for an image of shape `shape`.
14201425
"""
1421-
mask = torch.ones(shape, device=torch.get_default_device())
1426+
device = device or torch.get_default_device()
1427+
mask = torch.ones(shape, device=device)
14221428
mask = vignette(mask, margin, sigma, method=method)
14231429
return mask
14241430

@@ -1469,7 +1475,7 @@ def vignette(
14691475
mask = torch.zeros(mask_shape, device=img.device)
14701476
mask_slicer = [slice(None)] * i_dim + [slice(margin, None)]
14711477
mask[*mask_slicer] = 1.0
1472-
gauss_win = torch.signal.windows.gaussian(margin // 2, std=sigma)
1478+
gauss_win = torch.signal.windows.gaussian(margin // 2, std=sigma, device=img.device)
14731479
gauss_win = gauss_win / torch.sum(gauss_win)
14741480
mask = convolve1d(mask, gauss_win, dim=i_dim, padding="same")
14751481
mask_final_slicer = [slice(None)] * i_dim + [slice(len(gauss_win), len(gauss_win) + margin)]

0 commit comments

Comments
 (0)