diff --git a/hexrd/fitting/grains.py b/hexrd/fitting/grains.py index c65ae37d..36119a22 100644 --- a/hexrd/fitting/grains.py +++ b/hexrd/fitting/grains.py @@ -78,7 +78,33 @@ def fitGrain(gFull, instrument, reflections_dict, gFit = gFull[gFlag] - fitArgs = (gFull, gFlag, instrument, reflections_dict, + # objFuncFitGrain can run *significantly* faster if we convert the + # results to use a dictionary instead of lists or numpy arrays. + # Do that conversion here, if necessary. + new_reflections_dict = {} + for det_key, results in reflections_dict.items(): + if not isinstance(results, (list, np.ndarray)) or len(results) == 0: + # Maybe it's already a dict... + new_reflections_dict[det_key] = results + continue + + if isinstance(results, list): + hkls = np.atleast_2d( + np.vstack([x[2] for x in results]) + ).T + meas_xyo = np.atleast_2d( + np.vstack([np.r_[x[7], x[6][-1]] for x in results]) + ) + else: + hkls = np.atleast_2d(results[:, 2:5]).T + meas_xyo = np.atleast_2d(results[:, [15, 16, 12]]) + + new_reflections_dict[det_key] = { + 'hkls': hkls, + 'meas_xyo': meas_xyo, + } + + fitArgs = (gFull, gFlag, instrument, new_reflections_dict, bMat, wavelength, omePeriod) results = optimize.leastsq(objFuncFitGrain, gFit, args=fitArgs, diag=1./gScl[gFlag].flatten(), @@ -185,7 +211,7 @@ def objFuncFitGrain(gFit, gFull, gFlag, instrument.detector_parameters[det_key]) results = reflections_dict[det_key] - if len(results) == 0: + if not isinstance(results, dict) and len(results) == 0: continue """ @@ -214,6 +240,9 @@ def objFuncFitGrain(gFit, gFull, gFlag, elif isinstance(results, np.ndarray): hkls = np.atleast_2d(results[:, 2:5]).T meas_xyo = np.atleast_2d(results[:, [15, 16, 12]]) + elif isinstance(results, dict): + hkls = results['hkls'] + meas_xyo = results['meas_xyo'] # distortion handling if panel.distortion is not None: diff --git a/hexrd/imageseries/load/framecache.py b/hexrd/imageseries/load/framecache.py index 74a22160..55aa787b 100644 --- a/hexrd/imageseries/load/framecache.py +++ b/hexrd/imageseries/load/framecache.py @@ -4,10 +4,11 @@ import os from threading import Lock +import h5py import numpy as np -from scipy.sparse import csr_matrix import yaml -import h5py +from scipy.sparse import csr_matrix +from scipy.sparse.compressed import csr_sample_values from . import ImageSeriesAdapter, RegionType from ..imageseriesiter import ImageSeriesIterator @@ -209,6 +210,19 @@ def _load_framelist_if_needed(self): def __getitem__(self, key): self._load_framelist_if_needed() + if not isinstance(key, int): + # Extract only what we need from the sparse array + # using fancy indexing before we convert it to a + # numpy array. + mat = self._framelist[key[0]] + if len(key) == 3: + # This is definitely used frequently and needs to + # be performant. + return _extract_sparse_values(mat, key[1], key[2]) + elif len(key) == 2: + # Not sure if this will actually be used. + return mat[key[1]].toarray() + return self._framelist[key].toarray() def __iter__(self): @@ -313,3 +327,28 @@ def read_list_arrays_method_thread(i): range(num_frames))) return framelist + + +def _extract_sparse_values( + mat: csr_matrix, + row: np.ndarray, + col: np.ndarray, +) -> np.ndarray: + # This was first copied from here: https://github.com/scipy/scipy/blob/a465e2ce014c1b20b0e4b949e46361e5c2fb727e/scipy/sparse/_compressed.py#L556-L569 + # And then subsequently modified to return the internal `val` array. + + # It uses the `csr_sample_values()` function to extract values. This is + # excellent because it skips the creation of a new sparse array (and + # subsequent conversion to a numpy array *again*). It provides a nearly + # 10% performance boost for `pull_spots()`. + idx_dtype = mat.indices.dtype + M, N = mat._swap(mat.shape) + major, minor = mat._swap((row, col)) + major = np.asarray(major, dtype=idx_dtype) + minor = np.asarray(minor, dtype=idx_dtype) + + val = np.empty(major.size, dtype=mat.dtype) + csr_sample_values(M, N, mat.indptr, mat.indices, mat.data, + major.size, major.ravel(), minor.ravel(), val) + + return val.reshape(major.shape) diff --git a/hexrd/imageseries/load/hdf5.py b/hexrd/imageseries/load/hdf5.py index 606bcd95..0af2e76e 100644 --- a/hexrd/imageseries/load/hdf5.py +++ b/hexrd/imageseries/load/hdf5.py @@ -64,6 +64,11 @@ def __del__(self): warnings.warn("HDF5ImageSeries could not close h5 file") def __getitem__(self, key): + if not isinstance(key, int): + # FIXME: we do not yet support fancy indexing here. + # Fully expand the array then apply the fancy indexing. + return self[key[0]][*key[1:]] + if self._ndim == 2: if key != 0: raise IndexError( diff --git a/hexrd/imageseries/load/imagefiles.py b/hexrd/imageseries/load/imagefiles.py index 532d2c00..21d28d10 100644 --- a/hexrd/imageseries/load/imagefiles.py +++ b/hexrd/imageseries/load/imagefiles.py @@ -45,6 +45,11 @@ def __len__(self): return self._nframes def __getitem__(self, key): + if not isinstance(key, int): + # FIXME: we do not yet support fancy indexing here. + # Fully expand the array then apply the fancy indexing. + return self[key[0]][*key[1:]] + if self.singleframes: frame = None filename = self._files[key] diff --git a/hexrd/imageseries/load/rawimage.py b/hexrd/imageseries/load/rawimage.py index a272ea1b..dc953621 100644 --- a/hexrd/imageseries/load/rawimage.py +++ b/hexrd/imageseries/load/rawimage.py @@ -111,6 +111,11 @@ def __iter__(self): return ImageSeriesIterator(self) def __getitem__(self, key): + if not isinstance(key, int): + # FIXME: we do not yet support fancy indexing here. + # Fully expand the array then apply the fancy indexing. + return self[key[0]][*key[1:]] + count = key * self._frame_bytes + self.skipbytes # Ensure reading a frame the file is thread-safe diff --git a/hexrd/imageseries/process.py b/hexrd/imageseries/process.py index 99924a2a..31d910d9 100644 --- a/hexrd/imageseries/process.py +++ b/hexrd/imageseries/process.py @@ -44,7 +44,22 @@ def __init__(self, imser, oplist, **kwargs): self.addop(self.GAUSS_LAPLACE, self._gauss_laplace) def __getitem__(self, key): - return self._process_frame(self._get_index(key)) + if isinstance(key, int): + idx = key + rest = [] + else: + # Handle fancy indexing + idx = key[0] + rest = key[1:] + + idx = self._get_index(idx) + + if rest: + arg = tuple([idx, *rest]) + else: + arg = idx + + return self._process_frame(arg) def _get_index(self, key): return self._frames[key] if self._hasframelist else key diff --git a/hexrd/instrument/hedm_instrument.py b/hexrd/instrument/hedm_instrument.py index 7b810e92..b9c831cb 100644 --- a/hexrd/instrument/hedm_instrument.py +++ b/hexrd/instrument/hedm_instrument.py @@ -1723,7 +1723,7 @@ def pull_spots(self, plane_data, grain_params, contains_signal = False for i_frame in frame_indices: contains_signal = contains_signal or np.any( - ome_imgser[i_frame][ii, jj] > threshold + ome_imgser[i_frame, ii, jj] > threshold ) compl.append(contains_signal) patch_output.append((ii, jj, frame_indices)) @@ -1792,7 +1792,7 @@ def pull_spots(self, plane_data, grain_params, contains_signal = False patch_data_raw = [] for i_frame in frame_indices: - tmp = ome_imgser[i_frame][ijs[0], ijs[1]] + tmp = ome_imgser[i_frame, ijs[0], ijs[1]] contains_signal = contains_signal or np.any( tmp > threshold )