@@ -803,6 +803,8 @@ def fourier_gradient(image: Tensor) -> Tuple[Tensor, Tensor]:
803803 The y and x gradients.
804804 """
805805 u , v = torch .fft .fftfreq (image .shape [- 2 ]), torch .fft .fftfreq (image .shape [- 1 ])
806+ u = u .to (image .device )
807+ v = v .to (image .device )
806808 u , v = torch .meshgrid (u , v , indexing = "ij" )
807809 grad_y = torch .fft .ifft (torch .fft .fft (image , dim = - 2 ) * (2j * torch .pi ) * u , dim = - 2 )
808810 grad_x = torch .fft .ifft (torch .fft .fft (image , dim = - 1 ) * (2j * torch .pi ) * v , dim = - 1 )
@@ -950,6 +952,8 @@ def integrate_image_2d_fourier(grad_y: Tensor, grad_x: Tensor) -> Tensor:
950952 shape = grad_y .shape
951953 f = pmath .fft2_precise (grad_x + 1j * grad_y )
952954 y , x = torch .fft .fftfreq (shape [0 ]), torch .fft .fftfreq (shape [1 ])
955+ y = y .to (grad_y .device )
956+ x = x .to (grad_y .device )
953957
954958 # In PtychoShelves' get_img_int_2D.m, they set the numerator of r to be
955959 # exp(2j * pi * (x + y[:, None])) to shift it by 1 pixel. We should NOT
@@ -1413,12 +1417,14 @@ def generate_vignette_mask(
14131417 shape : tuple [int , int ],
14141418 margin : int = 20 ,
14151419 sigma : float = 1.0 ,
1416- method : Literal ["gaussian" , "linear" ] = "gaussian"
1420+ method : Literal ["gaussian" , "linear" ] = "gaussian" ,
1421+ device : Optional [torch .device ] = None ,
14171422):
14181423 """
14191424 Generate a vignette mask for an image of shape `shape`.
14201425 """
1421- mask = torch .ones (shape , device = torch .get_default_device ())
1426+ device = device or torch .get_default_device ()
1427+ mask = torch .ones (shape , device = device )
14221428 mask = vignette (mask , margin , sigma , method = method )
14231429 return mask
14241430
@@ -1469,7 +1475,7 @@ def vignette(
14691475 mask = torch .zeros (mask_shape , device = img .device )
14701476 mask_slicer = [slice (None )] * i_dim + [slice (margin , None )]
14711477 mask [* mask_slicer ] = 1.0
1472- gauss_win = torch .signal .windows .gaussian (margin // 2 , std = sigma )
1478+ gauss_win = torch .signal .windows .gaussian (margin // 2 , std = sigma , device = img . device )
14731479 gauss_win = gauss_win / torch .sum (gauss_win )
14741480 mask = convolve1d (mask , gauss_win , dim = i_dim , padding = "same" )
14751481 mask_final_slicer = [slice (None )] * i_dim + [slice (len (gauss_win ), len (gauss_win ) + margin )]
0 commit comments