Skip to content

Commit 70ce7f4

Browse files
committed
Significantly improve fit-grains performance
Profiling showed that the list comprehension in `objFuncFitGrain` was taking a significant amount of time. This is because that function is called repeatedly during the least squares fit, and repeatedly performing the list comprehension would take a substantial amount of time. Performing the list comprehension outside of `objFuncFitGrain` was shown to provide a substantial speedup. Signed-off-by: Patrick Avery <[email protected]>
1 parent 10a6691 commit 70ce7f4

File tree

1 file changed

+31
-2
lines changed

1 file changed

+31
-2
lines changed

hexrd/fitting/grains.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,33 @@ def fitGrain(gFull, instrument, reflections_dict,
7878

7979
gFit = gFull[gFlag]
8080

81-
fitArgs = (gFull, gFlag, instrument, reflections_dict,
81+
# objFuncFitGrain can run *significantly* faster if we convert the
82+
# results to use a dictionary instead of lists or numpy arrays.
83+
# Do that conversion here, if necessary.
84+
new_reflections_dict = {}
85+
for det_key, results in reflections_dict.items():
86+
if not isinstance(results, (list, np.ndarray)) or len(results) == 0:
87+
# Maybe it's already a dict...
88+
new_reflections_dict[det_key] = results
89+
continue
90+
91+
if isinstance(results, list):
92+
hkls = np.atleast_2d(
93+
np.vstack([x[2] for x in results])
94+
).T
95+
meas_xyo = np.atleast_2d(
96+
np.vstack([np.r_[x[7], x[6][-1]] for x in results])
97+
)
98+
else:
99+
hkls = np.atleast_2d(results[:, 2:5]).T
100+
meas_xyo = np.atleast_2d(results[:, [15, 16, 12]])
101+
102+
new_reflections_dict[det_key] = {
103+
'hkls': hkls,
104+
'meas_xyo': meas_xyo,
105+
}
106+
107+
fitArgs = (gFull, gFlag, instrument, new_reflections_dict,
82108
bMat, wavelength, omePeriod)
83109
results = optimize.leastsq(objFuncFitGrain, gFit, args=fitArgs,
84110
diag=1./gScl[gFlag].flatten(),
@@ -185,7 +211,7 @@ def objFuncFitGrain(gFit, gFull, gFlag,
185211
instrument.detector_parameters[det_key])
186212

187213
results = reflections_dict[det_key]
188-
if len(results) == 0:
214+
if not isinstance(results, dict) and len(results) == 0:
189215
continue
190216

191217
"""
@@ -214,6 +240,9 @@ def objFuncFitGrain(gFit, gFull, gFlag,
214240
elif isinstance(results, np.ndarray):
215241
hkls = np.atleast_2d(results[:, 2:5]).T
216242
meas_xyo = np.atleast_2d(results[:, [15, 16, 12]])
243+
elif isinstance(results, dict):
244+
hkls = results['hkls']
245+
meas_xyo = results['meas_xyo']
217246

218247
# distortion handling
219248
if panel.distortion is not None:

0 commit comments

Comments
 (0)