Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions hexrd/fitting/grains.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,33 @@ def fitGrain(gFull, instrument, reflections_dict,

gFit = gFull[gFlag]

fitArgs = (gFull, gFlag, instrument, reflections_dict,
# objFuncFitGrain can run *significantly* faster if we convert the
# results to use a dictionary instead of lists or numpy arrays.
# Do that conversion here, if necessary.
new_reflections_dict = {}
for det_key, results in reflections_dict.items():
if not isinstance(results, (list, np.ndarray)) or len(results) == 0:
# Maybe it's already a dict...
new_reflections_dict[det_key] = results
continue

if isinstance(results, list):
hkls = np.atleast_2d(
np.vstack([x[2] for x in results])
).T
meas_xyo = np.atleast_2d(
np.vstack([np.r_[x[7], x[6][-1]] for x in results])
)
else:
hkls = np.atleast_2d(results[:, 2:5]).T
meas_xyo = np.atleast_2d(results[:, [15, 16, 12]])

new_reflections_dict[det_key] = {
'hkls': hkls,
'meas_xyo': meas_xyo,
}

fitArgs = (gFull, gFlag, instrument, new_reflections_dict,
bMat, wavelength, omePeriod)
results = optimize.leastsq(objFuncFitGrain, gFit, args=fitArgs,
diag=1./gScl[gFlag].flatten(),
Expand Down Expand Up @@ -185,7 +211,7 @@ def objFuncFitGrain(gFit, gFull, gFlag,
instrument.detector_parameters[det_key])

results = reflections_dict[det_key]
if len(results) == 0:
if not isinstance(results, dict) and len(results) == 0:
continue

"""
Expand Down Expand Up @@ -214,6 +240,9 @@ def objFuncFitGrain(gFit, gFull, gFlag,
elif isinstance(results, np.ndarray):
hkls = np.atleast_2d(results[:, 2:5]).T
meas_xyo = np.atleast_2d(results[:, [15, 16, 12]])
elif isinstance(results, dict):
hkls = results['hkls']
meas_xyo = results['meas_xyo']

# distortion handling
if panel.distortion is not None:
Expand Down
43 changes: 41 additions & 2 deletions hexrd/imageseries/load/framecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import os
from threading import Lock

import h5py
import numpy as np
from scipy.sparse import csr_matrix
import yaml
import h5py
from scipy.sparse import csr_matrix
from scipy.sparse.compressed import csr_sample_values

from . import ImageSeriesAdapter, RegionType
from ..imageseriesiter import ImageSeriesIterator
Expand Down Expand Up @@ -209,6 +210,19 @@ def _load_framelist_if_needed(self):

def __getitem__(self, key):
self._load_framelist_if_needed()
if not isinstance(key, int):
# Extract only what we need from the sparse array
# using fancy indexing before we convert it to a
# numpy array.
mat = self._framelist[key[0]]
if len(key) == 3:
# This is definitely used frequently and needs to
# be performant.
return _extract_sparse_values(mat, key[1], key[2])
elif len(key) == 2:
# Not sure if this will actually be used.
return mat[key[1]].toarray()

return self._framelist[key].toarray()

def __iter__(self):
Expand Down Expand Up @@ -313,3 +327,28 @@ def read_list_arrays_method_thread(i):
range(num_frames)))

return framelist


def _extract_sparse_values(
mat: csr_matrix,
row: np.ndarray,
col: np.ndarray,
) -> np.ndarray:
# This was first copied from here: https://github.com/scipy/scipy/blob/a465e2ce014c1b20b0e4b949e46361e5c2fb727e/scipy/sparse/_compressed.py#L556-L569
# And then subsequently modified to return the internal `val` array.

# It uses the `csr_sample_values()` function to extract values. This is
# excellent because it skips the creation of a new sparse array (and
# subsequent conversion to a numpy array *again*). It provides a nearly
# 10% performance boost for `pull_spots()`.
idx_dtype = mat.indices.dtype
M, N = mat._swap(mat.shape)
major, minor = mat._swap((row, col))
major = np.asarray(major, dtype=idx_dtype)
minor = np.asarray(minor, dtype=idx_dtype)

val = np.empty(major.size, dtype=mat.dtype)
csr_sample_values(M, N, mat.indptr, mat.indices, mat.data,
major.size, major.ravel(), minor.ravel(), val)

return val.reshape(major.shape)
5 changes: 5 additions & 0 deletions hexrd/imageseries/load/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ def __del__(self):
warnings.warn("HDF5ImageSeries could not close h5 file")

def __getitem__(self, key):
if not isinstance(key, int):
# FIXME: we do not yet support fancy indexing here.
# Fully expand the array then apply the fancy indexing.
return self[key[0]][*key[1:]]

if self._ndim == 2:
if key != 0:
raise IndexError(
Expand Down
5 changes: 5 additions & 0 deletions hexrd/imageseries/load/imagefiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ def __len__(self):
return self._nframes

def __getitem__(self, key):
if not isinstance(key, int):
# FIXME: we do not yet support fancy indexing here.
# Fully expand the array then apply the fancy indexing.
return self[key[0]][*key[1:]]

if self.singleframes:
frame = None
filename = self._files[key]
Expand Down
5 changes: 5 additions & 0 deletions hexrd/imageseries/load/rawimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ def __iter__(self):
return ImageSeriesIterator(self)

def __getitem__(self, key):
if not isinstance(key, int):
# FIXME: we do not yet support fancy indexing here.
# Fully expand the array then apply the fancy indexing.
return self[key[0]][*key[1:]]

count = key * self._frame_bytes + self.skipbytes

# Ensure reading a frame the file is thread-safe
Expand Down
17 changes: 16 additions & 1 deletion hexrd/imageseries/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,22 @@ def __init__(self, imser, oplist, **kwargs):
self.addop(self.GAUSS_LAPLACE, self._gauss_laplace)

def __getitem__(self, key):
return self._process_frame(self._get_index(key))
if isinstance(key, int):
idx = key
rest = []
else:
# Handle fancy indexing
idx = key[0]
rest = key[1:]

idx = self._get_index(idx)

if rest:
arg = tuple([idx, *rest])
else:
arg = idx

return self._process_frame(arg)

def _get_index(self, key):
return self._frames[key] if self._hasframelist else key
Expand Down
4 changes: 2 additions & 2 deletions hexrd/instrument/hedm_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,7 +1723,7 @@ def pull_spots(self, plane_data, grain_params,
contains_signal = False
for i_frame in frame_indices:
contains_signal = contains_signal or np.any(
ome_imgser[i_frame][ii, jj] > threshold
ome_imgser[i_frame, ii, jj] > threshold
)
compl.append(contains_signal)
patch_output.append((ii, jj, frame_indices))
Expand Down Expand Up @@ -1792,7 +1792,7 @@ def pull_spots(self, plane_data, grain_params,
contains_signal = False
patch_data_raw = []
for i_frame in frame_indices:
tmp = ome_imgser[i_frame][ijs[0], ijs[1]]
tmp = ome_imgser[i_frame, ijs[0], ijs[1]]
contains_signal = contains_signal or np.any(
tmp > threshold
)
Expand Down
Loading