Skip to content

Commit 9c595e1

Browse files
authored
Merge pull request #728 from ChristosT/add-hdf5-support
Add fch5, an HDF5-based format for storing framecaches
2 parents 6213ab6 + 301a673 commit 9c595e1

File tree

5 files changed

+277
-34
lines changed

5 files changed

+277
-34
lines changed

conda.recipe/meta.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ requirements:
4242
- lxml >=4.9.2
4343
- fast-histogram
4444
- h5py
45+
- hdf5plugin
4546
- lmfit
4647
- matplotlib-base
4748
- numba

hexrd/imageseries/load/framecache.py

Lines changed: 80 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,17 @@
66
import numpy as np
77
from scipy.sparse import csr_matrix
88
import yaml
9+
import h5py
910

1011
from . import ImageSeriesAdapter
1112
from ..imageseriesiter import ImageSeriesIterator
1213
from .metadata import yamlmeta
14+
from hexrd.utils.hdf5 import unwrap_h5_to_dict
15+
from hexrd.utils.compatibility import h5py_read_string
16+
17+
import multiprocessing
18+
from concurrent.futures import ThreadPoolExecutor
19+
1320

1421
class FrameCacheImageSeriesAdapter(ImageSeriesAdapter):
1522
"""collection of images in HDF5 format"""
@@ -26,13 +33,25 @@ def __init__(self, fname, style='npz', **kwargs):
2633
self._framelist = []
2734
self._framelist_was_loaded = False
2835
self._load_framelist_lock = Lock()
36+
# TODO extract style from filename ?
37+
self._style = style.lower()
38+
39+
ncpus = multiprocessing.cpu_count()
40+
self._max_workers = kwargs.get('max_workers', ncpus)
2941

30-
if style.lower() in ('yml', 'yaml', 'test'):
42+
if self._style in ('yml', 'yaml', 'test'):
3143
self._from_yml = True
3244
self._load_yml()
33-
else:
45+
elif self._style == "npz":
3446
self._from_yml = False
3547
self._load_cache()
48+
elif self._style == "fch5":
49+
self._from_yml = False
50+
self._load_cache()
51+
else:
52+
raise TypeError(f"Unknown style format for loading data: {style}."
53+
"Known style formats: 'npz', 'fch5' 'yml', ",
54+
"'yaml', 'test'")
3655

3756
def _load_yml(self):
3857
with open(self._fname, "r") as f:
@@ -45,6 +64,29 @@ def _load_yml(self):
4564
self._meta = yamlmeta(d['meta'], path=self._cache)
4665

4766
def _load_cache(self):
67+
if self._style == 'fch5':
68+
self._load_cache_fch5()
69+
else:
70+
self._load_cache_npz()
71+
72+
def _load_cache_fch5(self):
73+
with h5py.File(self._fname, "r") as file:
74+
if 'HEXRD_FRAMECACHE_VERSION' not in file.attrs.keys():
75+
raise NotImplementedError("Unsupported file. "
76+
"HEXRD_FRAMECACHE_VERSION "
77+
"is missing!")
78+
version = file.attrs.get('HEXRD_FRAMECACHE_VERSION', 0)
79+
if version != 1:
80+
raise NotImplementedError("Framecache version is not "
81+
f"supported: {version}")
82+
83+
self._shape = file["shape"][()]
84+
self._nframes = file["nframes"][()]
85+
self._dtype = np.dtype(h5py_read_string(file["dtype"]))
86+
self._meta = {}
87+
unwrap_h5_to_dict(file["metadata"], self._meta)
88+
89+
def _load_cache_npz(self):
4890
arrs = np.load(self._fname)
4991
# HACK: while the loaded npz file has a getitem method
5092
# that mimics a dict, it doesn't have a "pop" method.
@@ -79,6 +121,41 @@ def _load_cache(self):
79121

80122
def _load_framelist(self):
81123
"""load into list of csr sparse matrices"""
124+
if self._style == 'fch5':
125+
self._load_framelist_fch5()
126+
else:
127+
self._load_framelist_npz()
128+
129+
def _load_framelist_fch5(self):
130+
self._framelist = [None] * self._nframes
131+
with h5py.File(self._fname, "r") as file:
132+
frame_id = file["frame_ids"]
133+
data = file["data"]
134+
indices = file["indices"]
135+
136+
def read_list_arrays_method_thread(i):
137+
frame_data = data[frame_id[2*i]: frame_id[2*i+1]]
138+
frame_indices = indices[frame_id[2*i]: frame_id[2*i+1]]
139+
row = frame_indices[:, 0]
140+
col = frame_indices[:, 1]
141+
mat_data = frame_data[:, 0]
142+
frame = csr_matrix((mat_data, (row, col)),
143+
shape=self._shape,
144+
dtype=self._dtype)
145+
self._framelist[i] = frame
146+
return
147+
148+
kwargs = {
149+
"max_workers": self._max_workers,
150+
}
151+
with ThreadPoolExecutor(**kwargs) as executor:
152+
# Evaluate the results via `list()`, so that if an exception is
153+
# raised in a thread, it will be re-raised and visible to the
154+
# user.
155+
list(executor.map(read_list_arrays_method_thread,
156+
range(self._nframes)))
157+
158+
def _load_framelist_npz(self):
82159
self._framelist = []
83160
if self._from_yml:
84161
bpath = os.path.dirname(self._fname)
@@ -149,6 +226,6 @@ def __getitem__(self, key):
149226
def __iter__(self):
150227
return ImageSeriesIterator(self)
151228

152-
#@memoize
229+
# @memoize
153230
def __len__(self):
154231
return self._nframes

0 commit comments

Comments
 (0)