Skip to content

Commit 9f2fa21

Browse files
authored
Merge pull request #53 from BAMresearch/fix_loaders
HDF loader and YAML loader need to match each other and the template IOSource better
2 parents 0680786 + 09206a5 commit 9f2fa21

File tree

4 files changed

+162
-112
lines changed

4 files changed

+162
-112
lines changed

src/modacor/io/hdf/hdf_loader.py

Lines changed: 72 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,92 +1,101 @@
11
# SPDX-License-Identifier: BSD-3-Clause
2-
# Copyright 2025 MoDaCor Authors
3-
#
4-
# Redistribution and use in source and binary forms, with or without modification,
5-
# are permitted provided that the following conditions are met:
6-
# 1. Redistributions of source code must retain the above copyright notice, this
7-
# list of conditions and the following disclaimer.
8-
# 2. Redistributions in binary form must reproduce the above copyright notice,
9-
# this list of conditions and the following disclaimer in the documentation
10-
# and/or other materials provided with the distribution.
11-
# 3. Neither the name of the copyright holder nor the names of its contributors
12-
# may be used to endorse or promote products derived from this software without
13-
# specific prior written permission.
14-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND
15-
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16-
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
18-
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
19-
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
20-
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
21-
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22-
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
23-
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24-
25-
__license__ = "BSD-3-Clause"
26-
__copyright__ = "Copyright 2025 MoDaCor Authors"
27-
__status__ = "Alpha"
2+
# /usr/bin/env python3
3+
# -*- coding: utf-8 -*-
284

5+
from __future__ import annotations
6+
7+
from typing import Any, Optional
8+
9+
__coding__ = "utf-8"
10+
__author__ = "Tim Snow, Brian R. Pauw"
11+
__copyright__ = "Copyright 2025, The MoDaCor team"
12+
__date__ = "22/10/2025"
13+
__status__ = "Development" # "Development", "Production"
14+
# end of header and standard imports
15+
16+
__all__ = ["HDFLoader"]
2917

3018
from logging import WARNING
31-
from os.path import abspath
19+
from pathlib import Path
3220

3321
import h5py
22+
import numpy as np
3423

35-
from modacor.dataclasses.basedata import BaseData
3624
from modacor.dataclasses.messagehandler import MessageHandler
3725

26+
# from modacor.dataclasses.basedata import BaseData
27+
from modacor.io.io_source import ArraySlice
28+
3829
from ..io_source import IoSource
3930

4031

4132
class HDFLoader(IoSource):
42-
def __init__(self, source_reference: str, logging_level=WARNING):
33+
_data_cache: dict[str, np.ndarray] = None
34+
_file_path: Path | None = None
35+
_static_metadata_cache: dict[str, Any] = None
36+
37+
def __init__(self, source_reference: str, logging_level=WARNING, resource_location: Path | str | None = None):
4338
super().__init__(source_reference)
44-
self.hdf_logger = MessageHandler(level=logging_level, name="hdf5logger")
45-
self._file_path = None
46-
self._file_reference = None
39+
self.logger = MessageHandler(level=logging_level, name="HDFLoader")
40+
self._file_path = Path(resource_location) if resource_location is not None else None
41+
# self._file_reference = None # let's not leave open file references lying around if we can help it.
4742
self._file_datasets = []
4843
self._file_datasets_shapes = {}
44+
self._file_datasets_dtypes = {}
45+
self._data_cache = {}
46+
self._static_metadata_cache = {}
4947

50-
def _open_file(self, file_path=None):
51-
if file_path is None:
52-
error = "No filepath given"
53-
self.hdf_logger.log.error(error)
54-
raise OSError(error)
55-
56-
try:
57-
self._file_reference = h5py.File(file_path, "r")
58-
self._file_path = abspath(file_path)
59-
self._file_reference.visititems(self._find_datasets)
60-
except OSError as error:
61-
self.hdf_logger.logger.error(error)
62-
raise OSError(error)
63-
64-
def _close_file(self):
48+
def _preload(self):
49+
assert self._file_path.is_file(), self.logger.error(f"HDF5 file {self._file_path} does not exist.")
6550
try:
66-
self._file_reference.close()
67-
self._file_path = None
68-
self._file_reference = None
69-
self._file_datasets.clear()
70-
self._file_datasets_shapes.clear()
51+
with h5py.File(self._file_path, "r") as f:
52+
f.visititems(self._find_datasets)
7153
except OSError as error:
72-
self.hdf_logger.log.error(error)
54+
self.logger.log.error(error)
7355
raise OSError(error)
7456

7557
def _find_datasets(self, path_name, path_object):
7658
"""
7759
An internal function to be used to walk the tree of an HDF5 file and return a list of
7860
the datasets within
7961
"""
80-
if isinstance(self._file_reference[path_name], h5py._hl.dataset.Dataset):
62+
if isinstance(path_object, h5py._hl.dataset.Dataset):
8163
self._file_datasets.append(path_name)
82-
self._file_datasets_shapes[path_name] = self._file_reference[path_name].shape
83-
84-
def get_data(self, data_key: str) -> BaseData:
85-
raise (NotImplementedError("get_data method not yet implemented in HDFLoader class."))
64+
self._file_datasets_shapes[path_name] = path_object.shape
65+
self._file_datasets_dtypes[path_name] = path_object.dtype
8666

8767
def get_static_metadata(self, data_key):
88-
raise (
89-
NotImplementedError(
90-
"get_static_metadata method not yet implemented in HDFLoader class."
91-
)
92-
)
68+
if data_key not in self._static_metadata_cache:
69+
with h5py.File(self._file_path, "r") as f:
70+
value = f[data_key][()]
71+
# decode bytes to string if necessary
72+
if isinstance(value, bytes):
73+
value = value.decode("utf-8")
74+
self._static_metadata_cache[data_key] = value
75+
return self._static_metadata_cache[data_key]
76+
77+
def get_data(self, data_key: str, load_slice: ArraySlice = ...) -> np.ndarray:
78+
if data_key not in self._data_cache:
79+
with h5py.File(self._file_path, "r") as f:
80+
data_array = f[data_key][load_slice] # if load_slice is not None else f[data_key][()]
81+
self._data_cache[data_key] = np.array(data_array)
82+
return self._data_cache[data_key]
83+
84+
def get_data_shape(self, data_key: str) -> tuple[int, ...]:
85+
if data_key in self._file_datasets_shapes:
86+
return self._file_datasets_shapes[data_key]
87+
return ()
88+
89+
def get_data_dtype(self, data_key: str) -> np.dtype | None:
90+
if data_key in self._file_datasets_dtypes:
91+
return self._file_datasets_dtypes[data_key]
92+
return None
93+
94+
def get_data_attributes(self, data_key: str) -> dict[str, Any]:
95+
attributes = {}
96+
with h5py.File(self._file_path, "r") as f:
97+
if data_key in f:
98+
dataset = f[data_key]
99+
for attr_key in dataset.attrs:
100+
attributes[attr_key] = dataset.attrs[attr_key]
101+
return attributes

src/modacor/io/yaml/yaml_loader.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
import numpy as np
2121
import yaml
2222

23-
from modacor.administration.licenses import BSD3Clause as __license__ # noqa: F401
2423
from modacor.dataclasses.messagehandler import MessageHandler
24+
from modacor.io.io_source import ArraySlice
2525

2626
from ..io_source import IoSource
2727

@@ -50,21 +50,26 @@ class YAMLLoader(IoSource):
5050

5151
_yaml_data: dict[str, Any] = dict()
5252
_data_cache: dict[str, np.ndarray] = None
53+
_file_path: Path | None = None
54+
_static_metadata_cache: dict[str, Any] = None
5355

54-
def __init__(self, source_reference: str, logging_level=WARNING):
56+
def __init__(self, source_reference: str, logging_level=WARNING, resource_location: Path | str | None = None):
5557
super().__init__(source_reference)
56-
self.logger = MessageHandler(level=logging_level, name="StaticMetadata")
57-
self._data_cache = {} # for values with units and uncertainties
58+
self.logger = MessageHandler(level=logging_level, name="YAMLLoader")
59+
self._file_path = Path(resource_location) if resource_location is not None else None
60+
self._file_datasets = []
61+
self._file_datasets_shapes = {}
62+
self._data_cache = {} # for values that are float
5863
self._static_metadata_cache = {} # for other elements such as strings and tags
5964

60-
def _load_from_yaml(self, file_path: Path) -> None:
65+
def _preload(self) -> None:
6166
"""
6267
Load static metadata from a YAML file.
6368
This method should be implemented to parse the YAML file and populate
6469
the _data_cache with SourceData objects.
6570
"""
66-
assert file_path.exists(), f"Static metadataa file {file_path} does not exist."
67-
with open(file_path, "r") as f:
71+
assert self._file_path.is_file(), self.logger.error(f"Static metadata file {self._file_path} does not exist.")
72+
with open(self._file_path, "r") as f:
6873
self._yaml_data.update(yaml.safe_load(f))
6974

7075
def get_static_metadata(self, data_key: str) -> Any:
@@ -75,7 +80,7 @@ def get_static_metadata(self, data_key: str) -> Any:
7580
self.logger.error(f"Static metadata key '{data_key}' not in YAML data: {e}")
7681
return None
7782

78-
def get_data(self, data_key: str) -> np.ndarray:
83+
def get_data(self, data_key: str, load_slice: ArraySlice = ...) -> np.ndarray:
7984
"""
8085
Get the data from the static metadata.
8186
"""
@@ -84,4 +89,24 @@ def get_data(self, data_key: str) -> np.ndarray:
8489
# try to convert from the yaml data into an np.asarray
8590
self._data_cache.update({data_key: self.get_static_metadata(data_key)})
8691

87-
return np.asarray(self._data_cache.get(data_key), dtype=float)
92+
return np.asarray(self._data_cache.get(data_key), dtype=float)[load_slice]
93+
94+
def get_data_shape(self, data_key: str) -> tuple[int, ...]:
95+
"""
96+
Get the shape of the data from the static metadata.
97+
"""
98+
if data_key in self._data_cache:
99+
return np.asarray(self._data_cache.get(data_key)).shape
100+
return ()
101+
102+
def get_data_dtype(self, data_key: str) -> np.dtype | None:
103+
"""
104+
Get the data type of the data from the static metadata.
105+
"""
106+
if data_key in self._data_cache:
107+
return np.asarray(self._data_cache.get(data_key)).dtype
108+
return None
109+
110+
def get_data_attributes(self, data_key):
111+
# not implemented for YAML, so just call the superclass method
112+
return super().get_data_attributes(data_key)
Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,25 @@
11
# SPDX-License-Identifier: BSD-3-Clause
2-
# Copyright 2025 MoDaCor Authors
3-
#
4-
# Redistribution and use in source and binary forms, with or without modification,
5-
# are permitted provided that the following conditions are met:
6-
# 1. Redistributions of source code must retain the above copyright notice, this
7-
# list of conditions and the following disclaimer.
8-
# 2. Redistributions in binary form must reproduce the above copyright notice,
9-
# this list of conditions and the following disclaimer in the documentation
10-
# and/or other materials provided with the distribution.
11-
# 3. Neither the name of the copyright holder nor the names of its contributors
12-
# may be used to endorse or promote products derived from this software without
13-
# specific prior written permission.
14-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND
15-
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16-
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
18-
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
19-
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
20-
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
21-
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22-
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
23-
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24-
25-
__license__ = "BSD-3-Clause"
26-
__copyright__ = "Copyright 2025 MoDaCor Authors"
27-
__status__ = "Alpha"
2+
# /usr/bin/env python3
3+
# -*- coding: utf-8 -*-
284

5+
from __future__ import annotations
6+
7+
__coding__ = "utf-8"
8+
__author__ = "Brian R. Pauw"
9+
__copyright__ = "Copyright 2025, The MoDaCor team"
10+
__date__ = "06/06/2025"
11+
__status__ = "Development" # "Development", "Production"
12+
# end of header and standard imports
2913

3014
import tempfile
3115
import unittest
3216
from os import unlink
17+
from pathlib import Path
3318

3419
import h5py
3520
import numpy as np
3621

37-
from ....io.hdf.hdf_loader import *
22+
from ....io.hdf.hdf_loader import HDFLoader
3823

3924

4025
class TestHDFLoader(unittest.TestCase):
@@ -59,16 +44,47 @@ def tearDown(self):
5944
unlink(self.temp_file_path)
6045

6146
def test_open_file(self):
62-
self.test_hdf_loader._open_file(self.temp_file_path)
63-
self.assertEqual(self.temp_file_path, self.test_hdf_loader._file_path)
47+
self.test_hdf_loader._file_path = Path(self.temp_file_path)
48+
self.test_hdf_loader._preload()
49+
self.assertEqual(Path(self.temp_file_path), self.test_hdf_loader._file_path)
6450
self.assertEqual(self.temp_dataset_name, self.test_hdf_loader._file_datasets[0])
6551
self.assertEqual(self.temp_dataset_shape, self.test_hdf_loader._file_datasets_shapes[self.temp_dataset_name])
6652

67-
def test_close_file(self):
68-
self.test_open_file()
69-
self.test_hdf_loader._close_file()
53+
def test_get_data(self):
54+
self.test_hdf_loader._file_path = Path(self.temp_file_path)
55+
self.test_hdf_loader._preload()
56+
data_array = self.test_hdf_loader.get_data(self.temp_dataset_name)
57+
self.assertTrue(isinstance(data_array, np.ndarray))
58+
self.assertEqual(self.temp_dataset_shape, data_array.shape)
59+
60+
def test_get_data_with_slice(self):
61+
self.test_hdf_loader._file_path = Path(self.temp_file_path)
62+
self.test_hdf_loader._preload()
63+
data_array = self.test_hdf_loader.get_data(self.temp_dataset_name, load_slice=(slice(0, 5), slice(None)))
64+
self.assertTrue(isinstance(data_array, np.ndarray))
65+
self.assertEqual((5, 2), data_array.shape)
66+
67+
def test_get_data_shape(self):
68+
self.test_hdf_loader._file_path = Path(self.temp_file_path)
69+
self.test_hdf_loader._preload()
70+
data_shape = self.test_hdf_loader.get_data_shape(self.temp_dataset_name)
71+
self.assertEqual(self.temp_dataset_shape, data_shape)
72+
73+
def test_get_data_dtype(self):
74+
self.test_hdf_loader._file_path = Path(self.temp_file_path)
75+
self.test_hdf_loader._preload()
76+
data_dtype = self.test_hdf_loader.get_data_dtype(self.temp_dataset_name)
77+
self.assertEqual(np.dtype("float64"), data_dtype)
78+
79+
def test_get_static_metadata(self):
80+
self.test_hdf_loader._file_path = Path(self.temp_file_path)
81+
self.test_hdf_loader._preload()
82+
static_metadata = self.test_hdf_loader.get_static_metadata(self.temp_dataset_name)
83+
self.assertTrue(isinstance(static_metadata, np.ndarray))
84+
self.assertEqual(self.temp_dataset_shape, static_metadata.shape)
7085

71-
self.assertEqual(None, self.test_hdf_loader._file_path)
72-
self.assertEqual(None, self.test_hdf_loader._file_reference)
73-
self.assertEqual([], self.test_hdf_loader._file_datasets)
74-
self.assertEqual({}, self.test_hdf_loader._file_datasets_shapes)
86+
def test_get_data_attributes(self):
87+
self.test_hdf_loader._file_path = Path(self.temp_file_path)
88+
self.test_hdf_loader._preload()
89+
data_attributes = self.test_hdf_loader.get_data_attributes(self.temp_dataset_name)
90+
self.assertEqual({}, data_attributes) # No attributes set, should return empty dict

src/modacor/tests/io/yaml/test_yaml_loader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ def test_yaml_loader_initialization():
2525
"""
2626
Test the initialization of the YAMLLoader class.
2727
"""
28-
source = YAMLLoader("defaults")
29-
source._load_from_yaml(filepath)
28+
source = YAMLLoader("defaults", resource_location=filepath)
29+
source._preload()
3030
assert isinstance(source._yaml_data, dict)
3131
assert isinstance(source._data_cache, dict)
3232

3333

3434
def test_yaml_loader_get_value():
35-
source = YAMLLoader("defaults")
36-
source._load_from_yaml(filepath)
35+
source = YAMLLoader("defaults", resource_location=filepath)
36+
source._preload()
3737
# at this point, data_cache should be empty:
3838
assert source._data_cache == {}
3939
v = source.get_data("probe_properties/wavelength/value")

0 commit comments

Comments
 (0)