Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 8 additions & 34 deletions hexrd/imageseries/load/framecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
import h5py
import numpy as np
import yaml
from scipy.sparse import csr_matrix

# FIXME: figure out if there is a public way to import this function
from scipy.sparse._compressed import csr_sample_values
from scipy.sparse import csr_array

from . import ImageSeriesAdapter, RegionType
from ..imageseriesiter import ImageSeriesIterator
Expand Down Expand Up @@ -220,7 +217,9 @@ def __getitem__(self, key):
if len(key) == 3:
# This is definitely used frequently and needs to
# be performant.
return _extract_sparse_values(mat, key[1], key[2])
ind1 = key[1]
ind2 = key[2]
return mat[ind1.ravel(), ind2.ravel()].reshape(ind1.shape)
elif len(key) == 2:
# Not sure if this will actually be used.
return mat[key[1]].toarray()
Expand Down Expand Up @@ -264,15 +263,15 @@ def _load_framecache_npz(
num_frames: int,
shape: tuple[int, int],
dtype: np.dtype,
) -> list[csr_matrix]:
) -> list[csr_array]:

framelist = []
arrs = np.load(filepath)
for i in range(num_frames):
row = arrs[f"{i}_row"]
col = arrs[f"{i}_col"]
data = arrs[f"{i}_data"]
frame = csr_matrix((data, (row, col)),
frame = csr_array((data, (row, col)),
shape=shape,
dtype=dtype)

Expand All @@ -294,7 +293,7 @@ def _load_framecache_fch5(
shape: tuple[int, int],
dtype: np.dtype,
max_workers: int,
) -> list[csr_matrix]:
) -> list[csr_array]:

framelist = [None] * num_frames

Expand All @@ -309,7 +308,7 @@ def read_list_arrays_method_thread(i):
row = frame_indices[:, 0]
col = frame_indices[:, 1]
mat_data = frame_data[:, 0]
frame = csr_matrix((mat_data, (row, col)),
frame = csr_array((mat_data, (row, col)),
shape=shape,
dtype=dtype)

Expand All @@ -329,28 +328,3 @@ def read_list_arrays_method_thread(i):
range(num_frames)))

return framelist


def _extract_sparse_values(
    mat: csr_matrix,
    row: np.ndarray,
    col: np.ndarray,
) -> np.ndarray:
    """Sample ``mat`` at paired ``(row, col)`` positions as a dense array.

    Element ``k`` of the (flattened) result is the value of ``mat`` at
    ``(row.ravel()[k], col.ravel()[k])``; the result is then reshaped to
    the shape of the row index array.

    Parameters
    ----------
    mat
        Compressed sparse matrix to sample. Its ``indptr``, ``indices`` and
        ``data`` arrays are handed directly to ``csr_sample_values()``.
    row, col
        Index arrays (anything ``np.asarray`` accepts). They are paired
        element-wise after flattening, so they must contain the same number
        of elements.

    Returns
    -------
    np.ndarray
        Dense array of dtype ``mat.dtype`` holding the sampled values.

    Raises
    ------
    ValueError
        If ``row`` and ``col`` do not contain the same number of elements.

    Notes
    -----
    NOTE(review): no bounds checking of the index values is done here;
    presumably callers guarantee they lie inside ``mat.shape`` -- confirm.
    """
    # This was first copied from here: https://github.com/scipy/scipy/blob/a465e2ce014c1b20b0e4b949e46361e5c2fb727e/scipy/sparse/_compressed.py#L556-L569
    # And then subsequently modified to return the internal `val` array.

    # It uses the `csr_sample_values()` function to extract values. This is
    # excellent because it skips the creation of a new sparse array (and
    # subsequent conversion to a numpy array *again*). It provides a nearly
    # 10% performance boost for `pull_spots()`.
    idx_dtype = mat.indices.dtype
    # `_swap` maps (row, col) onto the (major, minor) axes of the compressed
    # layout (presumably the identity for CSR -- it is a private scipy hook).
    M, N = mat._swap(mat.shape)
    major, minor = mat._swap((row, col))
    major = np.asarray(major, dtype=idx_dtype)
    minor = np.asarray(minor, dtype=idx_dtype)

    # The low-level routine is told to read `major.size` entries from BOTH
    # flattened index buffers, so a size mismatch must be rejected up front
    # instead of letting it walk past the end of the shorter buffer (or
    # silently ignore the tail of the longer one).
    if major.size != minor.size:
        raise ValueError(
            "row and col index arrays must contain the same number of "
            f"elements (got sizes {np.size(row)} and {np.size(col)})"
        )

    val = np.empty(major.size, dtype=mat.dtype)
    csr_sample_values(M, N, mat.indptr, mat.indices, mat.data,
                      major.size, major.ravel(), minor.ravel(), val)

    return val.reshape(major.shape)
Loading