Skip to content

Commit 413c985

Browse files
committed
fix rebase mistake
1 parent 051fb84 commit 413c985

File tree

2 files changed

+54
-8
lines changed

2 files changed

+54
-8
lines changed

dl1_data_handler/image_mapper.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import astropy.units as u
1212

1313
from ctapipe.instrument.camera import PixelShape
14-
from ctapipe.core import TelescopeComponent
15-
from ctapipe.core.traits import Bool, Int
14+
from ctapipe.core import Component
15+
from ctapipe.core.traits import Bool, Int, Float
1616

1717
__all__ = [
1818
"ImageMapper",
@@ -27,7 +27,7 @@
2727
"HexagonalPatchMapper",
2828
]
2929

30-
class ImageMapper(TelescopeComponent):
30+
class ImageMapper(Component):
3131
"""
3232
Base component for mapping raw 1D vectors into 2D mapped images.
3333
@@ -90,8 +90,23 @@ def __init__(
9090
parent : ctapipe.core.Component or ctapipe.core.Tool
9191
Parent of this component in the configuration hierarchy,
9292
this is mutually exclusive with passing ``config``
93+
**kwargs
94+
Additional keyword arguments for traitlets. Non-traitlet kwargs
95+
(like 'subarray') are filtered out for compatibility.
9396
"""
9497

98+
# Filter out non-traitlet kwargs before passing to Component
99+
# This allows compatibility with ctapipe's reader which may pass extra kwargs
100+
component_kwargs = {
101+
key: value for key, value in kwargs.items()
102+
if self.class_own_traits().get(key) is not None
103+
}
104+
105+
super().__init__(
106+
config=config,
107+
parent=parent,
108+
**component_kwargs,
109+
)
95110
# Camera types
96111
self.geometry = geometry
97112
self.camera_type = self.geometry.name
@@ -1259,6 +1274,16 @@ class RebinMapper(ImageMapper):
12591274
),
12601275
).tag(config=True)
12611276

1277+
max_memory_gb = Float(
1278+
default_value=10,
1279+
allow_none=True,
1280+
help=(
1281+
"Maximum memory in GB that RebinMapper is allowed to allocate. "
1282+
"Set to None to disable memory checks. Default is 10 GB. "
1283+
"Note: RebinMapper uses approximately (image_shape * 10)^2 * image_shape^2 * 4 bytes."
1284+
),
1285+
).tag(config=True)
1286+
12621287
def __init__(
12631288
self,
12641289
geometry,
@@ -1298,6 +1323,26 @@ def __init__(
12981323
self.image_shape = self.interpolation_image_shape
12991324
self.internal_shape = self.image_shape + self.internal_pad * 2
13001325
self.rebinning_mult_factor = 10
1326+
1327+
# Validate memory requirements before proceeding (if max_memory_gb is set)
1328+
if self.max_memory_gb is not None:
1329+
# RebinMapper uses a fine grid (internal_shape * rebinning_mult_factor)^2
1330+
# and creates a mapping matrix of shape (fine_grid_size, internal_shape, internal_shape)
1331+
fine_grid_size = (self.internal_shape * self.rebinning_mult_factor) ** 2
1332+
estimated_memory_gb = (
1333+
fine_grid_size * self.internal_shape * self.internal_shape * 4
1334+
) / (1024**3) # 4 bytes per float32
1335+
1336+
if estimated_memory_gb > self.max_memory_gb:
1337+
raise ValueError(
1338+
f"RebinMapper with image_shape={self.image_shape} would require "
1339+
f"approximately {estimated_memory_gb:.1f} GB of memory, which exceeds "
1340+
f"the limit of {self.max_memory_gb:.1f} GB. "
1341+
f"To allow this allocation, set max_memory_gb to a higher value or None. "
1342+
f"Alternatively, consider using a smaller interpolation_image_shape (recommended < 60) "
1343+
f"or use BilinearMapper or BicubicMapper instead, which are more memory-efficient."
1344+
)
1345+
13011346
# Creating the hexagonal and the output grid for the conversion methods.
13021347
input_grid, output_grid = super()._get_grids_for_interpolation()
13031348
# Calculate the mapping table

dl1_data_handler/reader.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def _construct_mono_example_identifiers(self):
668668
self.n_signal_events = np.count_nonzero(
669669
self.example_identifiers[class_column] == 1
670670
)
671-
if self.input_url_background:
671+
if self.input_url_background or (isinstance(self, DLTriggerReader) and not self.one_class):
672672
self.n_bkg_events = np.count_nonzero(
673673
self.example_identifiers[class_column] == 0
674674
)
@@ -2206,10 +2206,11 @@ def _add_trigger_table(self, batch):
22062206
# Packing info
22072207
with tables.open_file(self.input_trigger_files[0], 'r') as h5file:
22082208
node = h5file.get_node('/table')
2209-
self._trigger_mask_shape = tuple(node._v_attrs.trigger_mask_shape)
2210-
self._trigger_mask_bits = int(node._v_attrs.trigger_mask_bits)
2211-
self._trigger_mask_packed_len = int(node._v_attrs.trigger_mask_packed_len)
2212-
self._trigger_mask_bitorder = str(node._v_attrs.trigger_mask_bitorder)
2209+
if hasattr(node._v_attrs, 'trigger_mask_shape'):
2210+
self._trigger_mask_shape = tuple(node._v_attrs.trigger_mask_shape)
2211+
self._trigger_mask_bits = int(node._v_attrs.trigger_mask_bits)
2212+
self._trigger_mask_packed_len = int(node._v_attrs.trigger_mask_packed_len)
2213+
self._trigger_mask_bitorder = str(node._v_attrs.trigger_mask_bitorder)
22132214

22142215
for file_idx, trigger_file in enumerate(self.input_trigger_files):
22152216
tdscan_table = read_table(trigger_file, "/table")

0 commit comments

Comments
 (0)