Skip to content

Commit 20bb998

Browse files
authored
Merge pull request #816 from HEXRD/fit-grains-perf
Significantly improve fit-grains performance
2 parents 0d5b5f1 + 49851c7 commit 20bb998

File tree

7 files changed

+105
-7
lines changed

7 files changed

+105
-7
lines changed

hexrd/fitting/grains.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,33 @@ def fitGrain(gFull, instrument, reflections_dict,
7878

7979
gFit = gFull[gFlag]
8080

81-
fitArgs = (gFull, gFlag, instrument, reflections_dict,
81+
# objFuncFitGrain can run *significantly* faster if we convert the
82+
# results to use a dictionary instead of lists or numpy arrays.
83+
# Do that conversion here, if necessary.
84+
new_reflections_dict = {}
85+
for det_key, results in reflections_dict.items():
86+
if not isinstance(results, (list, np.ndarray)) or len(results) == 0:
87+
# Maybe it's already a dict...
88+
new_reflections_dict[det_key] = results
89+
continue
90+
91+
if isinstance(results, list):
92+
hkls = np.atleast_2d(
93+
np.vstack([x[2] for x in results])
94+
).T
95+
meas_xyo = np.atleast_2d(
96+
np.vstack([np.r_[x[7], x[6][-1]] for x in results])
97+
)
98+
else:
99+
hkls = np.atleast_2d(results[:, 2:5]).T
100+
meas_xyo = np.atleast_2d(results[:, [15, 16, 12]])
101+
102+
new_reflections_dict[det_key] = {
103+
'hkls': hkls,
104+
'meas_xyo': meas_xyo,
105+
}
106+
107+
fitArgs = (gFull, gFlag, instrument, new_reflections_dict,
82108
bMat, wavelength, omePeriod)
83109
results = optimize.leastsq(objFuncFitGrain, gFit, args=fitArgs,
84110
diag=1./gScl[gFlag].flatten(),
@@ -185,7 +211,7 @@ def objFuncFitGrain(gFit, gFull, gFlag,
185211
instrument.detector_parameters[det_key])
186212

187213
results = reflections_dict[det_key]
188-
if len(results) == 0:
214+
if not isinstance(results, dict) and len(results) == 0:
189215
continue
190216

191217
"""
@@ -214,6 +240,9 @@ def objFuncFitGrain(gFit, gFull, gFlag,
214240
elif isinstance(results, np.ndarray):
215241
hkls = np.atleast_2d(results[:, 2:5]).T
216242
meas_xyo = np.atleast_2d(results[:, [15, 16, 12]])
243+
elif isinstance(results, dict):
244+
hkls = results['hkls']
245+
meas_xyo = results['meas_xyo']
217246

218247
# distortion handling
219248
if panel.distortion is not None:

hexrd/imageseries/load/framecache.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@
44
import os
55
from threading import Lock
66

7+
import h5py
78
import numpy as np
8-
from scipy.sparse import csr_matrix
99
import yaml
10-
import h5py
10+
from scipy.sparse import csr_matrix
11+
from scipy.sparse.compressed import csr_sample_values
1112

1213
from . import ImageSeriesAdapter, RegionType
1314
from ..imageseriesiter import ImageSeriesIterator
@@ -209,6 +210,19 @@ def _load_framelist_if_needed(self):
209210

210211
def __getitem__(self, key):
    """Return frame data for a frame index or a fancy-index tuple.

    Parameters
    ----------
    key
        Either a single frame index (a Python int, a NumPy integer
        scalar, or anything else the frame list accepts), or a tuple
        whose first element is the frame index:

        * ``(frame,)``          -> the whole frame as a dense array
        * ``(frame, rows)``     -> ``frame[rows]`` as a dense array
        * ``(frame, rows, cols)`` -> values at the paired indices

    Raises
    ------
    IndexError
        If a tuple key has more than three elements.
    """
    self._load_framelist_if_needed()

    # Dispatch on `tuple` rather than on `int`: NumPy integer scalars
    # (e.g. np.int64) are not `int` instances, and they must keep taking
    # the plain frame-index path.
    if not isinstance(key, tuple):
        return self._framelist[key].toarray()

    # Extract only what we need from the sparse array using fancy
    # indexing before we convert it to a numpy array.
    mat = self._framelist[key[0]]
    n_idx = len(key)
    if n_idx == 3:
        # This is definitely used frequently and needs to be performant.
        return _extract_sparse_values(mat, key[1], key[2])
    if n_idx == 2:
        # Not sure if this will actually be used.
        return mat[key[1]].toarray()
    if n_idx == 1:
        return mat.toarray()

    raise IndexError(
        f'too many indices for frame cache image series: got {n_idx}, '
        'expected at most 3 (frame, rows, cols)'
    )
213227

214228
def __iter__(self):
@@ -313,3 +327,28 @@ def read_list_arrays_method_thread(i):
313327
range(num_frames)))
314328

315329
return framelist
330+
331+
332+
def _extract_sparse_values(
    mat: csr_matrix,
    row: np.ndarray,
    col: np.ndarray,
) -> np.ndarray:
    """Return the values of `mat` at the paired indices ``(row, col)``.

    Parameters
    ----------
    mat
        Compressed sparse matrix to sample from.
    row, col
        Integer index arrays (or scalars). If their shapes differ they
        are broadcast against each other first.

    Returns
    -------
    np.ndarray
        Dense array of dtype ``mat.dtype`` with the common (broadcast)
        shape of `row` and `col`.

    Raises
    ------
    ValueError
        If `row` and `col` cannot be broadcast to a common shape.

    Notes
    -----
    Indices are not range-checked in this function; callers must pass
    in-bounds indices.
    """
    # This was first copied from here: https://github.com/scipy/scipy/blob/a465e2ce014c1b20b0e4b949e46361e5c2fb727e/scipy/sparse/_compressed.py#L556-L569
    # And then subsequently modified to return the internal `val` array.

    # It uses the `csr_sample_values()` function to extract values. This is
    # excellent because it skips the creation of a new sparse array (and
    # subsequent conversion to a numpy array *again*). It provides a nearly
    # 10% performance boost for `pull_spots()`.
    idx_dtype = mat.indices.dtype
    M, N = mat._swap(mat.shape)
    major, minor = mat._swap((row, col))
    major = np.asarray(major, dtype=idx_dtype)
    minor = np.asarray(minor, dtype=idx_dtype)

    # `major.size` is passed below as the sample count for BOTH index
    # arrays, so they must hold the same number of elements. Broadcast
    # only when needed to keep the common same-shape path cheap.
    if major.shape != minor.shape:
        major, minor = np.broadcast_arrays(major, minor)

    val = np.empty(major.size, dtype=mat.dtype)
    csr_sample_values(M, N, mat.indptr, mat.indices, mat.data,
                      major.size, major.ravel(), minor.ravel(), val)

    return val.reshape(major.shape)

hexrd/imageseries/load/hdf5.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ def __del__(self):
6464
warnings.warn("HDF5ImageSeries could not close h5 file")
6565

6666
def __getitem__(self, key):
67+
if not isinstance(key, int):
68+
# FIXME: we do not yet support fancy indexing here.
69+
# Fully expand the array then apply the fancy indexing.
70+
return self[key[0]][*key[1:]]
71+
6772
if self._ndim == 2:
6873
if key != 0:
6974
raise IndexError(

hexrd/imageseries/load/imagefiles.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ def __len__(self):
4545
return self._nframes
4646

4747
def __getitem__(self, key):
48+
if not isinstance(key, int):
49+
# FIXME: we do not yet support fancy indexing here.
50+
# Fully expand the array then apply the fancy indexing.
51+
return self[key[0]][*key[1:]]
52+
4853
if self.singleframes:
4954
frame = None
5055
filename = self._files[key]

hexrd/imageseries/load/rawimage.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ def __iter__(self):
111111
return ImageSeriesIterator(self)
112112

113113
def __getitem__(self, key):
114+
if not isinstance(key, int):
115+
# FIXME: we do not yet support fancy indexing here.
116+
# Fully expand the array then apply the fancy indexing.
117+
return self[key[0]][*key[1:]]
118+
114119
count = key * self._frame_bytes + self.skipbytes
115120

116121
# Ensure that reading a frame from the file is thread-safe

hexrd/imageseries/process.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,22 @@ def __init__(self, imser, oplist, **kwargs):
4444
self.addop(self.GAUSS_LAPLACE, self._gauss_laplace)
4545

4646
def __getitem__(self, key):
    """Return the processed frame for `key`.

    Parameters
    ----------
    key
        Either a single frame index, or a tuple ``(frame, *rest)`` for
        fancy indexing. The frame index is mapped through
        `_get_index()`; any remaining elements are forwarded unchanged
        to `_process_frame()` together with the mapped index.
    """
    # Dispatch on `tuple` rather than on `int`: NumPy integer scalars
    # (e.g. np.int64) are not `int` instances and must still be treated
    # as plain frame indices.
    if not isinstance(key, tuple):
        return self._process_frame(self._get_index(key))

    # Handle fancy indexing
    idx = self._get_index(key[0])
    rest = key[1:]
    if not rest:
        # A 1-tuple carries no extra indices; pass the bare frame index.
        return self._process_frame(idx)

    return self._process_frame((idx, *rest))
4863

4964
def _get_index(self, key):
5065
return self._frames[key] if self._hasframelist else key

hexrd/instrument/hedm_instrument.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,7 +1723,7 @@ def pull_spots(self, plane_data, grain_params,
17231723
contains_signal = False
17241724
for i_frame in frame_indices:
17251725
contains_signal = contains_signal or np.any(
1726-
ome_imgser[i_frame][ii, jj] > threshold
1726+
ome_imgser[i_frame, ii, jj] > threshold
17271727
)
17281728
compl.append(contains_signal)
17291729
patch_output.append((ii, jj, frame_indices))
@@ -1792,7 +1792,7 @@ def pull_spots(self, plane_data, grain_params,
17921792
contains_signal = False
17931793
patch_data_raw = []
17941794
for i_frame in frame_indices:
1795-
tmp = ome_imgser[i_frame][ijs[0], ijs[1]]
1795+
tmp = ome_imgser[i_frame, ijs[0], ijs[1]]
17961796
contains_signal = contains_signal or np.any(
17971797
tmp > threshold
17981798
)

0 commit comments

Comments
 (0)