Skip to content

Commit cf45859

Browse files
authored
Merge pull request #833 from HEXRD/polar-view-speedups
Implement several speedups for polar view generation
2 parents 6ea9729 + f88a57d commit cf45859

File tree

2 files changed

+198
-82
lines changed

2 files changed

+198
-82
lines changed

hexrd/instrument/detector.py

Lines changed: 106 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
COATING_DEFAULT, FILTER_DEFAULTS, PHOSPHOR_DEFAULT
88
)
99
from hexrd.instrument.physics_package import AbstractPhysicsPackage
10+
import numba
1011
import numpy as np
1112

1213
from hexrd import constants as ct
@@ -1111,9 +1112,12 @@ def interpolate_nearest(self, xy, img, pad_with_nans=True):
11111112
int_xy[on_panel] = int_vals
11121113
return int_xy
11131114

1114-
def interpolate_bilinear(self, xy, img, pad_with_nans=True,
1115-
clip_to_panel=True,
1116-
on_panel: Optional[np.ndarray] = None):
1115+
def interpolate_bilinear(
1116+
self,
1117+
xy: np.ndarray,
1118+
img: np.ndarray,
1119+
pad_with_nans: bool = True,
1120+
):
11171121
"""
11181122
Interpolate an image array at the specified cartesian points.
11191123
@@ -1123,13 +1127,10 @@ def interpolate_bilinear(self, xy, img, pad_with_nans=True,
11231127
Array of cartesian coordinates in the image plane at which
11241128
to evaluate intensity.
11251129
img : array_like
1126-
2-dimensional image array.
1130+
2-dimensional image array. The shape must match (rows, cols).
11271131
pad_with_nans : bool, optional
11281132
Toggle for assigning NaN to points that fall off the detector.
11291133
The default is True.
1130-
on_panel : np.ndarray, optional
1131-
If you want to skip clip_to_panel() for performance reasons,
1132-
just provide an array of which pixels are on the panel.
11331134
11341135
Returns
11351136
-------
@@ -1141,28 +1142,30 @@ def interpolate_bilinear(self, xy, img, pad_with_nans=True,
11411142
-----
11421143
TODO: revisit normalization in here?
11431144
"""
1145+
fill_value = np.nan if pad_with_nans else 0
1146+
int_xy = np.full(len(xy), fill_value, dtype=float)
11441147

1145-
is_2d = img.ndim == 2
1146-
right_shape = img.shape[0] == self.rows and img.shape[1] == self.cols
1147-
assert (
1148-
is_2d and right_shape
1149-
), "input image must be 2-d with shape (%d, %d)" % (
1150-
self.rows,
1151-
self.cols,
1152-
)
1148+
# clip away points too close to or off the detector edges
1149+
xy_clip, on_panel = self.clip_to_panel(xy, buffer_edges=True)
11531150

1154-
# initialize output with nans
1155-
if pad_with_nans:
1156-
int_xy = np.nan * np.ones(len(xy))
1157-
else:
1158-
int_xy = np.zeros(len(xy))
1151+
# Generate the interpolation dict
1152+
interp_dict = self._generate_bilinear_interp_dict(xy_clip)
11591153

1160-
if on_panel is None:
1161-
# clip away points too close to or off the edges of the detector
1162-
xy_clip, on_panel = self.clip_to_panel(xy, buffer_edges=True)
1163-
else:
1164-
xy_clip = xy[on_panel]
1154+
# Set the output and return
1155+
int_xy[on_panel] = _interpolate_bilinear(img, **interp_dict)
1156+
return int_xy
11651157

1158+
def _generate_bilinear_interp_dict(
1159+
self,
1160+
xy_clip: np.ndarray,
1161+
) -> dict[str, np.ndarray]:
1162+
"""Compute bilinear interpolation multipliers and indices for the panel
1163+
1164+
If you are going to be using the same panel settings and performing
1165+
interpolation on multiple images, it is advised to run this beforehand
1166+
to precompute the interpolation parameters, so you can use them
1167+
repeatedly.
1168+
"""
11661169
# grab fractional pixel indices of clipped points
11671170
ij_frac = self.cartToPixel(xy_clip)
11681171

@@ -1182,20 +1185,24 @@ def interpolate_bilinear(self, xy, img, pad_with_nans=True,
11821185
j_ceil = j_floor + 1
11831186
j_ceil_img = _fix_indices(j_ceil, 0, self.cols - 1)
11841187

1185-
# first interpolate at top/bottom rows
1186-
row_floor_int = (j_ceil - ij_frac[:, 1]) * img[
1187-
i_floor_img, j_floor_img
1188-
] + (ij_frac[:, 1] - j_floor) * img[i_floor_img, j_ceil_img]
1189-
row_ceil_int = (j_ceil - ij_frac[:, 1]) * img[
1190-
i_ceil_img, j_floor_img
1191-
] + (ij_frac[:, 1] - j_floor) * img[i_ceil_img, j_ceil_img]
1192-
1193-
# next interpolate across cols
1194-
int_vals = (i_ceil - ij_frac[:, 0]) * row_floor_int + (
1195-
ij_frac[:, 0] - i_floor
1196-
) * row_ceil_int
1197-
int_xy[on_panel] = int_vals
1198-
return int_xy
1188+
# Compute differences between raw coordinates to use for interpolation
1189+
j_ceil_sub = j_ceil - ij_frac[:, 1]
1190+
j_floor_sub = ij_frac[:, 1] - j_floor
1191+
i_ceil_sub = i_ceil - ij_frac[:, 0]
1192+
i_floor_sub = ij_frac[:, 0] - i_floor
1193+
1194+
return {
1195+
# Compute interpolation multipliers for every pixel
1196+
'cc': j_ceil_sub * i_ceil_sub,
1197+
'fc': j_floor_sub * i_ceil_sub,
1198+
'cf': j_ceil_sub * i_floor_sub,
1199+
'ff': j_floor_sub * i_floor_sub,
1200+
# Store needed pixel indices
1201+
'i_floor_img': i_floor_img,
1202+
'j_floor_img': j_floor_img,
1203+
'i_ceil_img': i_ceil_img,
1204+
'j_ceil_img': j_ceil_img,
1205+
}
11991206

12001207
def make_powder_rings(
12011208
self,
@@ -2100,3 +2107,63 @@ def _row_edge_vec(rows, pixel_size_row):
21002107

21012108
def _col_edge_vec(cols, pixel_size_col):
21022109
return pixel_size_col * (np.arange(cols + 1) - 0.5 * cols)
2110+
2111+
2112+
@numba.njit(nogil=True, cache=True)
2113+
def _interpolate_bilinear(
2114+
img: np.ndarray,
2115+
cc: np.ndarray,
2116+
fc: np.ndarray,
2117+
cf: np.ndarray,
2118+
ff: np.ndarray,
2119+
i_floor_img: np.ndarray,
2120+
j_floor_img: np.ndarray,
2121+
i_ceil_img: np.ndarray,
2122+
j_ceil_img: np.ndarray,
2123+
) -> np.ndarray:
2124+
# The math is faster and uses the GIL less (which is more
2125+
# multi-threading friendly) when we run this code in numba.
2126+
result = np.zeros(i_floor_img.shape[0], dtype=img.dtype)
2127+
on_panel_idx = np.arange(i_floor_img.shape[0])
2128+
_interpolate_bilinear_in_place(
2129+
img,
2130+
cc,
2131+
fc,
2132+
cf,
2133+
ff,
2134+
i_floor_img,
2135+
j_floor_img,
2136+
i_ceil_img,
2137+
j_ceil_img,
2138+
on_panel_idx,
2139+
result,
2140+
)
2141+
return result
2142+
2143+
2144+
@numba.njit(nogil=True, cache=True)
2145+
def _interpolate_bilinear_in_place(
2146+
img: np.ndarray,
2147+
cc: np.ndarray,
2148+
fc: np.ndarray,
2149+
cf: np.ndarray,
2150+
ff: np.ndarray,
2151+
i_floor_img: np.ndarray,
2152+
j_floor_img: np.ndarray,
2153+
i_ceil_img: np.ndarray,
2154+
j_ceil_img: np.ndarray,
2155+
on_panel_idx: np.ndarray,
2156+
output_img: np.ndarray,
2157+
):
2158+
# The math is faster and uses the GIL less (which is more
2159+
# multi-threading friendly) when we run this code in numba.
2160+
# Running in-place eliminates some intermediary arrays for
2161+
# even faster performance.
2162+
for i in range(on_panel_idx.shape[0]):
2163+
idx = on_panel_idx[i]
2164+
output_img[idx] += (
2165+
cc[i] * img[i_floor_img[i], j_floor_img[i]] +
2166+
fc[i] * img[i_floor_img[i], j_ceil_img[i]] +
2167+
cf[i] * img[i_ceil_img[i], j_floor_img[i]] +
2168+
ff[i] * img[i_ceil_img[i], j_ceil_img[i]]
2169+
)

hexrd/projections/polar.py

Lines changed: 92 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22

33
from hexrd import constants
4+
from hexrd.instrument.detector import _interpolate_bilinear_in_place
45
from hexrd.material.crystallography import PlaneData
56
from hexrd.xrdutil.utils import (
67
_project_on_detector_cylinder,
@@ -77,13 +78,15 @@ def __init__(self, plane_data, instrument,
7778
self._instrument = instrument
7879

7980
self._coordinate_mapping = None
81+
self._nan_mask = None
8082
self._cache_coordinate_map = cache_coordinate_map
8183
if cache_coordinate_map:
8284
# It is important to generate the cached map now, rather than
8385
# later, because this object might be sent to other processes
8486
# for parallelization, and it will be faster if the mapping
8587
# is already generated.
8688
self._coordinate_mapping = self._generate_coordinate_mapping()
89+
self._nan_mask = self._generate_nan_mask(self._coordinate_mapping)
8790

8891
@property
8992
def instrument(self):
@@ -261,13 +264,21 @@ def warp_image(self, image_dict, pad_with_nans=False,
261264
if self.cache_coordinate_map:
262265
# The mapping should have already been generated.
263266
mapping = self._coordinate_mapping
267+
nan_mask = self._nan_mask
264268
else:
265269
# Otherwise, we must generate it every time
266270
mapping = self._generate_coordinate_mapping()
271+
# FIXME: this performs a bilinear interpolation
272+
# each time. Maybe it doesn't matter that much
273+
# since the interpolation is very fast now, but
274+
# it'd be nice if we could figure out another
275+
# way to do it.
276+
nan_mask = self._generate_nan_mask(mapping)
267277

268278
return self._warp_image_from_coordinate_map(
269279
image_dict,
270280
mapping,
281+
nan_mask,
271282
pad_with_nans=pad_with_nans,
272283
do_interpolation=do_interpolation,
273284
)
@@ -312,66 +323,104 @@ def _generate_coordinate_mapping(self) -> dict[str, dict[str, np.ndarray]]:
312323
xypts[on_plane, :] = valid_xys
313324

314325
_, on_panel = panel.clip_to_panel(xypts, buffer_edges=True)
326+
on_panel_idx = np.where(on_panel)[0]
327+
xy_clip = xypts[on_panel_idx]
328+
329+
bilinear_interp_dict = panel._generate_bilinear_interp_dict(
330+
xy_clip,
331+
)
315332

316333
mapping[detector_id] = {
317334
'xypts': xypts,
318-
'on_panel': on_panel,
335+
'on_panel_idx': on_panel_idx,
336+
'bilinear_interp_dict': bilinear_interp_dict,
319337
}
320338

321339
return mapping
322340

323-
def _warp_image_from_coordinate_map(
324-
self,
325-
image_dict: dict[str, np.ndarray],
326-
coordinate_map: dict[str, dict[str, np.ndarray]],
327-
pad_with_nans: bool = False,
328-
do_interpolation=True) -> np.ma.MaskedArray:
329-
330-
panel_buffer_fill_value = np.nan
331-
img_dict = dict.fromkeys(self.detectors)
332-
nan_mask = None
333-
for detector_id, panel in self.detectors.items():
334-
# Make a copy since we may modify
335-
img = image_dict[detector_id].copy()
341+
def _generate_nan_mask(
342+
self,
343+
coordinate_map: dict[str, dict[str, np.ndarray]],
344+
) -> np.ndarray:
345+
"""Generate the nan mask
336346
337-
# Before warping, mask out any pixels that are invalid,
338-
# so that they won't affect the results.
347+
This saves time during repeated calls to warp_image(),
348+
since the nan mask should stay the same and not change.
349+
"""
350+
mapping = coordinate_map
351+
# Generate the nan mask that we will use
352+
nan_mask = np.ones(self.shape, dtype=bool).flatten()
353+
for detector_id, panel in self.detectors.items():
354+
on_panel_idx = mapping[detector_id]['on_panel_idx']
355+
xypts = mapping[detector_id]['xypts']
356+
interp_dict = mapping[detector_id]['bilinear_interp_dict']
357+
358+
# To reproduce old behavior, perform a bilinear
359+
# interpolation so that if any point has neighboring
360+
# pixels that are nan, that point will also be excluded.
361+
dummy_img = np.zeros(panel.shape)
339362
buffer = panel_buffer_as_2d_array(panel)
340-
if (np.issubdtype(type(panel_buffer_fill_value), np.floating) and
341-
not np.issubdtype(img.dtype, np.floating)):
342-
# Convert to float. This is especially important
343-
# for nan, since it is a float...
344-
img = img.astype(float)
363+
dummy_img[~buffer] = np.nan
364+
365+
output = np.full(len(xypts), np.nan)
366+
output[on_panel_idx] = 0
367+
_interpolate_bilinear_in_place(
368+
dummy_img,
369+
**interp_dict,
370+
on_panel_idx=on_panel_idx,
371+
output_img=output,
372+
)
345373

346-
img[~buffer] = panel_buffer_fill_value
374+
nan_mask[~np.isnan(output)] = False
347375

348-
xypts = coordinate_map[detector_id]['xypts']
349-
on_panel = coordinate_map[detector_id]['on_panel']
376+
return nan_mask.reshape(self.shape)
377+
378+
def _warp_image_from_coordinate_map(
379+
self,
380+
image_dict: dict[str, np.ndarray],
381+
coordinate_map: dict[str, dict[str, np.ndarray]],
382+
nan_mask: np.ndarray,
383+
pad_with_nans: bool = False,
384+
do_interpolation=True,
385+
) -> np.ma.MaskedArray:
386+
first_det = next(iter(self.detectors))
387+
# This is a flat image. We'll reshape at the end.
388+
summed_img = np.zeros(len(coordinate_map[first_det]['xypts']))
389+
for detector_id, panel in self.detectors.items():
390+
img = image_dict[detector_id]
391+
panel_map = coordinate_map[detector_id]
392+
393+
xypts = panel_map['xypts']
394+
on_panel_idx = panel_map['on_panel_idx']
395+
interp_dict = panel_map['bilinear_interp_dict']
350396

351397
if do_interpolation:
352-
this_img = panel.interpolate_bilinear(
353-
xypts, img,
354-
pad_with_nans=pad_with_nans,
355-
on_panel=on_panel).reshape(self.shape)
356-
else:
357-
this_img = panel.interpolate_nearest(
358-
xypts, img,
359-
pad_with_nans=pad_with_nans).reshape(self.shape)
360-
361-
# It is faster to keep track of the global nans like this
362-
# rather than the previous way we were doing it...
363-
img_nans = np.isnan(this_img)
364-
if nan_mask is None:
365-
nan_mask = img_nans
398+
# It's faster if we do _interpolate_bilinear ourselves,
399+
# since we already have all appropriate options set up.
400+
_interpolate_bilinear_in_place(
401+
img,
402+
**interp_dict,
403+
on_panel_idx=on_panel_idx,
404+
output_img=summed_img,
405+
)
366406
else:
367-
nan_mask = np.logical_and(img_nans, nan_mask)
407+
summed_img += panel.interpolate_nearest(
408+
xypts,
409+
img,
410+
# DON'T pad with nans, so we can sum images together
411+
# correctly. We'll pad with nans later.
412+
pad_with_nans=False,
413+
)
414+
415+
# Now reshape the image to the appropriate shape
416+
output_img = summed_img.reshape(self.shape)
368417

369-
this_img[img_nans] = 0
370-
img_dict[detector_id] = this_img
418+
if pad_with_nans:
419+
# We pad with nans manually here
420+
output_img[nan_mask] = np.nan
371421

372-
summed_img = np.sum(list(img_dict.values()), axis=0)
373422
return np.ma.masked_array(
374-
data=summed_img, mask=nan_mask, fill_value=0.
423+
data=output_img, mask=nan_mask, fill_value=0.
375424
)
376425

377426
def tth_to_pixel(self, tth):

0 commit comments

Comments
 (0)