Skip to content

Commit 284b40b

Browse files
authored
Merge pull request #54 from BAMresearch/improve_io_sources
Improve io sources
2 parents 9f2fa21 + a07a18b commit 284b40b

File tree

8 files changed

+58
-224
lines changed

8 files changed

+58
-224
lines changed

src/modacor/io/hdf/hdf_loader.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from __future__ import annotations
66

7-
from typing import Any, Optional
7+
from typing import Any
88

99
__coding__ = "utf-8"
1010
__author__ = "Tim Snow, Brian R. Pauw"
@@ -35,7 +35,7 @@ class HDFLoader(IoSource):
3535
_static_metadata_cache: dict[str, Any] = None
3636

3737
def __init__(self, source_reference: str, logging_level=WARNING, resource_location: Path | str | None = None):
38-
super().__init__(source_reference)
38+
super().__init__(source_reference=source_reference)
3939
self.logger = MessageHandler(level=logging_level, name="HDFLoader")
4040
self._file_path = Path(resource_location) if resource_location is not None else None
4141
# self._file_reference = None # let's not leave open file references lying around if we can help it.
@@ -44,6 +44,7 @@ def __init__(self, source_reference: str, logging_level=WARNING, resource_locati
4444
self._file_datasets_dtypes = {}
4545
self._data_cache = {}
4646
self._static_metadata_cache = {}
47+
self._preload() # load the HDF5 file structure immediately so we have some information, but not the data
4748

4849
def _preload(self):
4950
assert self._file_path.is_file(), self.logger.error(f"HDF5 file {self._file_path} does not exist.")

src/modacor/io/io_registry.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

src/modacor/io/io_source.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from __future__ import annotations
66

7+
import attrs
8+
79
__coding__ = "utf-8"
810
__authors__ = ["Malte Storm", "Brian R. Pauw"] # add names to the list as appropriate
911
__copyright__ = "Copyright 2025, The MoDaCor team"
@@ -60,9 +62,9 @@ class IoSource:
6062
slices can be separated by double semicolon ';;'.
6163
"""
6264

63-
type_reference = "IoSource"
64-
6565
configuration: dict[str, Any] = field(factory=default_config)
66+
source_reference: str = field(default="", converter=str, validator=attrs.validators.instance_of(str))
67+
type_reference: str = "IoSource"
6668

6769
def get_data(self, data_key: str, load_slice: Optional[ArraySlice] = None) -> np.ndarray:
6870
"""

src/modacor/io/io_sources.py

Lines changed: 33 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,25 @@
11
# SPDX-License-Identifier: BSD-3-Clause
2-
# Copyright 2025 MoDaCor Authors
3-
#
4-
# Redistribution and use in source and binary forms, with or without modification,
5-
# are permitted provided that the following conditions are met:
6-
# 1. Redistributions of source code must retain the above copyright notice, this
7-
# list of conditions and the following disclaimer.
8-
# 2. Redistributions in binary form must reproduce the above copyright notice,
9-
# this list of conditions and the following disclaimer in the documentation
10-
# and/or other materials provided with the distribution.
11-
# 3. Neither the name of the copyright holder nor the names of its contributors
12-
# may be used to endorse or promote products derived from this software without
13-
# specific prior written permission.
14-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND
15-
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16-
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17-
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
18-
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
19-
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
20-
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
21-
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22-
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
23-
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24-
25-
__license__ = "BSD-3-Clause"
26-
__copyright__ = "Copyright 2025 MoDaCor Authors"
27-
__status__ = "Alpha"
2+
# /usr/bin/env python3
3+
# -*- coding: utf-8 -*-
4+
5+
from __future__ import annotations
6+
7+
__coding__ = "utf-8"
8+
__author__ = "Brian R. Pauw"
9+
__copyright__ = "Copyright 2025, The MoDaCor team"
10+
__date__ = "06/06/2025"
11+
__status__ = "Development" # "Development", "Production"
12+
# end of header and standard imports
13+
2814
__all__ = ["IoSources"]
2915

3016

31-
from typing import Any
17+
from typing import Any, Optional
3218

3319
import numpy as np
3420
from attrs import define, field
3521

36-
from modacor.io.io_source import IoSource
22+
from modacor.io.io_source import ArraySlice, IoSource
3723

3824

3925
@define
@@ -47,21 +33,24 @@ class IoSources:
4733

4834
defined_sources: dict[str, IoSource] = field(factory=dict)
4935

50-
def register_source(self, source_reference: str, source: IoSource):
36+
def register_source(self, source: IoSource, source_reference: str | None = None):
5137
"""
52-
Register a new source class with the given name.
38+
Register a new source instance under the given name. If no source_reference is provided, the
39+
source's own source_reference attribute will be used.
5340
5441
Parameters
5542
----------
56-
source_reference : str
57-
The reference name of the source to register.
5843
source : IoSource
5944
The IoSource instance to register.
45+
source_reference : str | None, optional
46+
The reference name of the source to register.
6047
"""
61-
if not isinstance(source_reference, str):
62-
raise TypeError("source_name must be a string")
6348
if not isinstance(source, IoSource):
6449
raise TypeError("source_class must be a subclass of IoSource")
50+
if source_reference is None:
51+
source_reference = source.source_reference
52+
if not isinstance(source_reference, str):
53+
raise TypeError("source_name must be a string")
6554
if source_reference in self.defined_sources:
6655
raise ValueError(f"Source {source_reference} already registered.")
6756
self.defined_sources[source_reference] = source
@@ -109,7 +98,7 @@ def split_data_reference(self, data_reference: str) -> tuple[str, str]:
10998
)
11099
return _split[0], _split[1]
111100

112-
def get_data(self, data_reference: str, index: int) -> np.ndarray:
101+
def get_data(self, data_reference: str, load_slice: Optional[ArraySlice] = ...) -> np.ndarray:
113102
"""
114103
Get data from the specified source using the provided data key.
115104
@@ -120,8 +109,8 @@ def get_data(self, data_reference: str, index: int) -> np.ndarray:
120109
----------
121110
data_reference : str
122111
The reference name of the source to access.
123-
index : int
124-
The index to access the data.
112+
load_slice : Optional[ArraySlice]
113+
A slice or tuple of slices to apply to the data. If None or ellipsis, the entire data is returned.
125114
126115
Returns
127116
-------
@@ -130,9 +119,9 @@ def get_data(self, data_reference: str, index: int) -> np.ndarray:
130119
"""
131120
_source_ref, _data_key = self.split_data_reference(data_reference)
132121
_source = self.get_source(_source_ref)
133-
return _source.get_data(index, _data_key)
122+
return _source.get_data(_data_key, load_slice=load_slice)
134123

135-
def get_data_shape(self, data_reference: str, index: int) -> np.ndarray:
124+
def get_data_shape(self, data_reference: str) -> np.ndarray:
136125
"""
137126
Get the shape of the data from the specified source using the provided data key.
138127
@@ -143,8 +132,6 @@ def get_data_shape(self, data_reference: str, index: int) -> np.ndarray:
143132
----------
144133
data_reference : str
145134
The reference name of the source to access.
146-
index : int
147-
The index to access the data.
148135
149136
Returns
150137
-------
@@ -153,9 +140,9 @@ def get_data_shape(self, data_reference: str, index: int) -> np.ndarray:
153140
"""
154141
_source_ref, _data_key = self.split_data_reference(data_reference)
155142
_source = self.get_source(_source_ref)
156-
return _source.get_data_shape(index, _data_key)
143+
return _source.get_data_shape(_data_key)
157144

158-
def get_data_dtype(self, data_reference: str, index: int) -> np.ndarray:
145+
def get_data_dtype(self, data_reference: str) -> np.ndarray:
159146
"""
160147
Get the dtype of the data from the specified source using the provided data key.
161148
@@ -166,8 +153,6 @@ def get_data_dtype(self, data_reference: str, index: int) -> np.ndarray:
166153
----------
167154
data_reference : str
168155
The reference name of the source to access.
169-
index : int
170-
The index to access the data.
171156
172157
Returns
173158
-------
@@ -176,9 +161,9 @@ def get_data_dtype(self, data_reference: str, index: int) -> np.ndarray:
176161
"""
177162
_source_ref, _data_key = self.split_data_reference(data_reference)
178163
_source = self.get_source(_source_ref)
179-
return _source.get_data_dtype(index, _data_key)
164+
return _source.get_data_dtype(_data_key)
180165

181-
def get_data_attributes(self, data_reference: str, index: int) -> np.ndarray:
166+
def get_data_attributes(self, data_reference: str) -> np.ndarray:
182167
"""
183168
Get the attributes of the data from the specified source using the provided data key.
184169
@@ -199,7 +184,7 @@ def get_data_attributes(self, data_reference: str, index: int) -> np.ndarray:
199184
"""
200185
_source_ref, _data_key = self.split_data_reference(data_reference)
201186
_source = self.get_source(_source_ref)
202-
return _source.get_data_attributes(index, _data_key)
187+
return _source.get_data_attributes(_data_key)
203188

204189
def get_static_metadata(self, data_reference: str) -> Any:
205190
"""

src/modacor/io/yaml/yaml_loader.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,14 @@ class YAMLLoader(IoSource):
5454
_static_metadata_cache: dict[str, Any] = None
5555

5656
def __init__(self, source_reference: str, logging_level=WARNING, resource_location: Path | str | None = None):
57-
super().__init__(source_reference)
57+
super().__init__(source_reference=source_reference)
5858
self.logger = MessageHandler(level=logging_level, name="YAMLLoader")
5959
self._file_path = Path(resource_location) if resource_location is not None else None
6060
self._file_datasets = []
6161
self._file_datasets_shapes = {}
6262
self._data_cache = {} # for values that are float
6363
self._static_metadata_cache = {} # for other elements such as strings and tags
64+
self._preload() # load the yaml data immediately
6465

6566
def _preload(self) -> None:
6667
"""

src/modacor/tests/io/hdf/test_hdf_loader.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,20 @@ class TestHDFLoader(unittest.TestCase):
2626
"""Testing class for modacor/io/hdf/hdf_loader.py"""
2727

2828
def setUp(self):
29-
self.test_hdf_loader = HDFLoader(source_reference="Test Data")
3029
self.temp_file_handle = tempfile.NamedTemporaryFile(delete=False, delete_on_close=False)
3130
self.temp_file_path = self.temp_file_handle.name
3231
self.temp_file_handle.close()
3332
self.temp_dataset_name = "dataset"
3433
self.temp_dataset_shape = (10, 2)
35-
self.temp_hdf_file = h5py.File(self.temp_file_path, "w")
36-
self.temp_hdf_file[self.temp_dataset_name] = np.zeros(self.temp_dataset_shape)
37-
self.temp_file_handle.close()
34+
with h5py.File(self.temp_file_path, "w") as hdf_file:
35+
hdf_file.create_dataset(
36+
self.temp_dataset_name, data=np.zeros(self.temp_dataset_shape), dtype="float64", compression="gzip"
37+
)
38+
39+
self.test_hdf_loader = HDFLoader(source_reference="Test Data", resource_location=self.temp_file_path)
3840

3941
def tearDown(self):
40-
self.test_h5_loader = None
42+
self.test_hdf_loader = None
4143
self.test_file_path = None
4244
self.test_dataset_name = None
4345
self.test_dataset_shape = None

0 commit comments

Comments
 (0)