Skip to content

Commit 2a3dc19

Browse files
authored
Merge pull request #821 from HEXRD/csr-array
Avoid using `csr_sample_values`
2 parents cfa0d3e + e830e1e commit 2a3dc19

File tree

1 file changed

+8
-34
lines changed

1 file changed

+8
-34
lines changed

hexrd/imageseries/load/framecache.py

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77
import h5py
88
import numpy as np
99
import yaml
10-
from scipy.sparse import csr_matrix
11-
12-
# FIXME: figure out if there is a public way to import this function
13-
from scipy.sparse._compressed import csr_sample_values
10+
from scipy.sparse import csr_array
1411

1512
from . import ImageSeriesAdapter, RegionType
1613
from ..imageseriesiter import ImageSeriesIterator
@@ -220,7 +217,9 @@ def __getitem__(self, key):
220217
if len(key) == 3:
221218
# This is definitely used frequently and needs to
222219
# be performant.
223-
return _extract_sparse_values(mat, key[1], key[2])
220+
ind1 = key[1]
221+
ind2 = key[2]
222+
return mat[ind1.ravel(), ind2.ravel()].reshape(ind1.shape)
224223
elif len(key) == 2:
225224
# Not sure if this will actually be used.
226225
return mat[key[1]].toarray()
@@ -264,15 +263,15 @@ def _load_framecache_npz(
264263
num_frames: int,
265264
shape: tuple[int, int],
266265
dtype: np.dtype,
267-
) -> list[csr_matrix]:
266+
) -> list[csr_array]:
268267

269268
framelist = []
270269
arrs = np.load(filepath)
271270
for i in range(num_frames):
272271
row = arrs[f"{i}_row"]
273272
col = arrs[f"{i}_col"]
274273
data = arrs[f"{i}_data"]
275-
frame = csr_matrix((data, (row, col)),
274+
frame = csr_array((data, (row, col)),
276275
shape=shape,
277276
dtype=dtype)
278277

@@ -294,7 +293,7 @@ def _load_framecache_fch5(
294293
shape: tuple[int, int],
295294
dtype: np.dtype,
296295
max_workers: int,
297-
) -> list[csr_matrix]:
296+
) -> list[csr_array]:
298297

299298
framelist = [None] * num_frames
300299

@@ -309,7 +308,7 @@ def read_list_arrays_method_thread(i):
309308
row = frame_indices[:, 0]
310309
col = frame_indices[:, 1]
311310
mat_data = frame_data[:, 0]
312-
frame = csr_matrix((mat_data, (row, col)),
311+
frame = csr_array((mat_data, (row, col)),
313312
shape=shape,
314313
dtype=dtype)
315314

@@ -329,28 +328,3 @@ def read_list_arrays_method_thread(i):
329328
range(num_frames)))
330329

331330
return framelist
332-
333-
334-
def _extract_sparse_values(
335-
mat: csr_matrix,
336-
row: np.ndarray,
337-
col: np.ndarray,
338-
) -> np.ndarray:
339-
# This was first copied from here: https://github.com/scipy/scipy/blob/a465e2ce014c1b20b0e4b949e46361e5c2fb727e/scipy/sparse/_compressed.py#L556-L569
340-
# And then subsequently modified to return the internal `val` array.
341-
342-
# It uses the `csr_sample_values()` function to extract values. This is
343-
# excellent because it skips the creation of a new sparse array (and
344-
# subsequent conversion to a numpy array *again*). It provides a nearly
345-
# 10% performance boost for `pull_spots()`.
346-
idx_dtype = mat.indices.dtype
347-
M, N = mat._swap(mat.shape)
348-
major, minor = mat._swap((row, col))
349-
major = np.asarray(major, dtype=idx_dtype)
350-
minor = np.asarray(minor, dtype=idx_dtype)
351-
352-
val = np.empty(major.size, dtype=mat.dtype)
353-
csr_sample_values(M, N, mat.indptr, mat.indices, mat.data,
354-
major.size, major.ravel(), minor.ravel(), val)
355-
356-
return val.reshape(major.shape)

0 commit comments

Comments
 (0)