Skip to content

Commit e70c014

Browse files
committed
Implement get_analogsignal_chunk() generically when a rawio class has_buffer_description_api=True
This should also solve the memmap and memory leak problem.
1 parent 3083513 commit e70c014

File tree

3 files changed

+106
-41
lines changed

3 files changed

+106
-41
lines changed

neo/rawio/baserawio.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@
7777

7878
from neo import logging_handler
7979

80+
from .utils import get_memmap_chunk_from_opened_file
81+
8082

8183
possible_raw_modes = [
8284
"one-file",
@@ -656,7 +658,12 @@ def get_signal_size(self, block_index: int, seg_index: int, stream_index: int |
656658
657659
"""
658660
stream_index = self._get_stream_index_from_arg(stream_index)
659-
return self._get_signal_size(block_index, seg_index, stream_index)
661+
662+
if not self.has_buffer_description_api:
663+
return self._get_signal_size(block_index, seg_index, stream_index)
664+
else:
665+
# use the buffer description
666+
return self._get_signal_size_generic(block_index, seg_index, stream_index)
660667

661668
def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int | None = None):
662669
"""
@@ -805,7 +812,11 @@ def get_analogsignal_chunk(
805812
if np.all(np.diff(channel_indexes) == 1):
806813
channel_indexes = slice(channel_indexes[0], channel_indexes[-1] + 1)
807814

808-
raw_chunk = self._get_analogsignal_chunk(block_index, seg_index, i_start, i_stop, stream_index, channel_indexes)
815+
if not self.has_buffer_description_api:
816+
raw_chunk = self._get_analogsignal_chunk(block_index, seg_index, i_start, i_stop, stream_index, channel_indexes)
817+
else:
818+
# use the buffer description
819+
raw_chunk = self._get_analogsignal_chunk_generic(block_index, seg_index, i_start, i_stop, stream_index, channel_indexes)
809820

810821
return raw_chunk
811822

@@ -1277,6 +1288,7 @@ def _get_signal_size(self, block_index: int, seg_index: int, stream_index: int):
12771288
12781289
All channels indexed must have the same size and t_start.
12791290
"""
1291+
# must NOT be implemented if has_buffer_description_api=True
12801292
raise (NotImplementedError)
12811293

12821294
def _get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int):
@@ -1304,7 +1316,7 @@ def _get_analogsignal_chunk(
13041316
-------
13051317
array of samples, with each requested channel in a column
13061318
"""
1307-
1319+
# must NOT be implemented if has_buffer_description_api=True
13081320
raise (NotImplementedError)
13091321

13101322
###
@@ -1355,6 +1367,59 @@ def get_analogsignal_buffer_description(self, block_index: int = 0, seg_index: i
13551367
def _get_analogsignal_buffer_description(self, block_index, seg_index, stream_index):
13561368
raise (NotImplementedError)
13571369

1370+
def _get_signal_size_generic(self, block_index, seg_index, stream_index):
1371+
# When has_buffer_description_api=True this used to avoid to write _get_analogsignal_chunk())
1372+
1373+
buffer_desc = self._get_signal_size(block_index, seg_index, stream_index)
1374+
return buffer_desc['shape'][0]
1375+
1376+
def _get_analogsignal_chunk_generic(
1377+
self,
1378+
block_index: int,
1379+
seg_index: int,
1380+
i_start: int | None,
1381+
i_stop: int | None,
1382+
stream_index: int,
1383+
channel_indexes: list[int] | None,
1384+
):
1385+
# When has_buffer_description_api=True this used to avoid to write _get_analogsignal_chunk())
1386+
buffer_desc = self._get_analogsignal_buffer_description( block_index, seg_index, stream_index)
1387+
1388+
if buffer_desc['type'] == 'binary':
1389+
1390+
# open files on demand and keep reference to opened file
1391+
if not hasattr(self, '_memmap_analogsignal_streams'):
1392+
self._memmap_analogsignal_streams = {}
1393+
if block_index not in self._memmap_analogsignal_streams:
1394+
self._memmap_analogsignal_streams[block_index] = {}
1395+
if seg_index not in self._memmap_analogsignal_streams[block_index]:
1396+
self._memmap_analogsignal_streams[block_index][seg_index] = {}
1397+
if stream_index not in self._memmap_analogsignal_streams[block_index][seg_index]:
1398+
fid = open(buffer_desc['file_path'], mode='rb')
1399+
self._memmap_analogsignal_streams[block_index][seg_index] = fid
1400+
else:
1401+
fid = self._memmap_analogsignal_streams[block_index][seg_index]
1402+
1403+
1404+
i_start = i_start or 0
1405+
i_stop = i_stop or buffer_desc['shape'][0]
1406+
1407+
num_channels = buffer_desc['shape'][1]
1408+
1409+
raw_sigs = get_memmap_chunk_from_opened_file(fid, num_channels, i_start, i_stop, np.dtype(buffer_desc['dtype']), file_offset=buffer_desc['file_offset'])
1410+
1411+
# this is a pre slicing when the stream do not contain all channels (for instance spikeglx when load_sync_channel=False)
1412+
channel_slice = buffer_desc.get('channel_slice', None)
1413+
if channel_slice is not None:
1414+
raw_sigs = raw_sigs[:, channel_slice]
1415+
1416+
# channel slice requested
1417+
if channel_indexes is not None:
1418+
raw_sigs = raw_sigs[:, channel_indexes]
1419+
else:
1420+
raise NotImplementedError()
1421+
1422+
return raw_sigs
13581423

13591424

13601425

neo/rawio/spikeglxrawio.py

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
_spike_channel_dtype,
6464
_event_channel_dtype,
6565
)
66-
from .utils import get_memmap_shape, get_memmap_chunk_from_open_file
66+
from .utils import get_memmap_shape
6767

6868

6969
class SpikeGLXRawIO(BaseRawIO):
@@ -127,7 +127,7 @@ def _parse_header(self):
127127
nb_segment = np.unique([info["seg_index"] for info in self.signals_info_list]).size
128128

129129

130-
self._memmaps = {}
130+
# self._memmaps = {}
131131
self.signals_info_dict = {}
132132
# on block
133133
self._buffer_descriptions = {0 :{}}
@@ -139,11 +139,11 @@ def _parse_header(self):
139139
self.signals_info_dict[key] = info
140140

141141
# create memmap
142-
data = np.memmap(info["bin_file"], dtype="int16", mode="r", offset=0, order="C")
142+
# data = np.memmap(info["bin_file"], dtype="int16", mode="r", offset=0, order="C")
143143
# this should be (info['sample_length'], info['num_chan'])
144144
# but some files are shortened
145-
data = data.reshape(-1, info["num_chan"])
146-
self._memmaps[key] = data
145+
# data = data.reshape(-1, info["num_chan"])
146+
# self._memmaps[key] = data
147147

148148
stream_index = stream_names.index(info["stream_name"])
149149
if seg_index not in self._buffer_descriptions[0]:
@@ -273,42 +273,42 @@ def _segment_t_start(self, block_index, seg_index):
273273
def _segment_t_stop(self, block_index, seg_index):
274274
return self._t_stops[seg_index]
275275

276-
def _get_signal_size(self, block_index, seg_index, stream_index):
277-
stream_id = self.header["signal_streams"][stream_index]["id"]
278-
memmap = self._memmaps[seg_index, stream_id]
279-
return int(memmap.shape[0])
276+
# def _get_signal_size(self, block_index, seg_index, stream_index):
277+
# stream_id = self.header["signal_streams"][stream_index]["id"]
278+
# memmap = self._memmaps[seg_index, stream_id]
279+
# return int(memmap.shape[0])
280280

281281
def _get_signal_t_start(self, block_index, seg_index, stream_index):
282282
return 0.0
283283

284-
def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes):
285-
stream_id = self.header["signal_streams"][stream_index]["id"]
286-
memmap = self._memmaps[seg_index, stream_id]
287-
stream_name = self.header["signal_streams"]["name"][stream_index]
288-
289-
# take care of sync channel
290-
info = self.signals_info_dict[0, stream_name]
291-
if not self.load_sync_channel and info["has_sync_trace"]:
292-
memmap = memmap[:, :-1]
293-
294-
# since we cut the memmap, we can simplify the channel selection
295-
if channel_indexes is None:
296-
channel_selection = slice(None)
297-
elif isinstance(channel_indexes, slice):
298-
channel_selection = channel_indexes
299-
elif not isinstance(channel_indexes, slice):
300-
if np.all(np.diff(channel_indexes) == 1):
301-
# consecutive channel then slice this avoid a copy (because of ndarray.take(...)
302-
# and so keep the underlying memmap
303-
channel_selection = slice(channel_indexes[0], channel_indexes[0] + len(channel_indexes))
304-
else:
305-
channel_selection = channel_indexes
306-
else:
307-
raise ValueError("get_analogsignal_chunk : channel_indexes" "must be slice or list or array of int")
308-
309-
raw_signals = memmap[slice(i_start, i_stop), channel_selection]
310-
311-
return raw_signals
284+
# def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes):
285+
# stream_id = self.header["signal_streams"][stream_index]["id"]
286+
# memmap = self._memmaps[seg_index, stream_id]
287+
# stream_name = self.header["signal_streams"]["name"][stream_index]
288+
289+
# # take care of sync channel
290+
# info = self.signals_info_dict[0, stream_name]
291+
# if not self.load_sync_channel and info["has_sync_trace"]:
292+
# memmap = memmap[:, :-1]
293+
294+
# # since we cut the memmap, we can simplify the channel selection
295+
# if channel_indexes is None:
296+
# channel_selection = slice(None)
297+
# elif isinstance(channel_indexes, slice):
298+
# channel_selection = channel_indexes
299+
# elif not isinstance(channel_indexes, slice):
300+
# if np.all(np.diff(channel_indexes) == 1):
301+
# # consecutive channel then slice this avoid a copy (because of ndarray.take(...)
302+
# # and so keep the underlying memmap
303+
# channel_selection = slice(channel_indexes[0], channel_indexes[0] + len(channel_indexes))
304+
# else:
305+
# channel_selection = channel_indexes
306+
# else:
307+
# raise ValueError("get_analogsignal_chunk : channel_indexes" "must be slice or list or array of int")
308+
309+
# raw_signals = memmap[slice(i_start, i_stop), channel_selection]
310+
311+
# return raw_signals
312312

313313
def _event_count(self, event_channel_idx, block_index=None, seg_index=None):
314314
timestamps, _, _ = self._get_event_timestamps(block_index, seg_index, event_channel_idx, None, None)

neo/rawio/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def get_memmap_shape(filename, dtype, num_channels=None, offset=0):
1717
return shape
1818

1919

20-
def get_memmap_chunk_from_open_file(fid, num_channels, start, stop, dtype, file_offset=0):
20+
def get_memmap_chunk_from_opened_file(fid, num_channels, start, stop, dtype, file_offset=0):
2121
"""
2222
Utility function to get a chunk as a memmap array directly from an opened file.
2323
Using this instead of memmap can avoid memory consumption when multiprocessing.

0 commit comments

Comments
 (0)