Commit

buffer description api in neo.rawio and xarray reference API bridge (#1513)

* Proof of concept of "buffer_description_api" and xarray reference API bridge

* Implement get_analogsignal_chunk() generically when a rawio class has_buffer_description_api=True.
This should also solve the memmap and memory leak problems.

* wip

* test on micromed

* rebase on buffer_id

* some fix

* make stream a slice of buffer and xarray api use buffer_id

* json api : winedr + winwcp

* buffer api : RawBinarySignalRawIO + RawMCSRawIO

* json api : neuroscope + openephysraw

* More reader with buffer description

* wip

* json api : start hdf5 on maxwell

* doc for signal_stream signal_buffer

* Thanks Zach

Co-authored-by: Zach McKenzie <[email protected]>

* Use class approach for buffer api : BaseRawWithBufferApiIO

* feedback

* Apply suggestions from code review

Co-authored-by: Heberto Mayorquin <[email protected]>
Co-authored-by: Zach McKenzie <[email protected]>

* clean

* oops

* more clean

---------

Co-authored-by: Zach McKenzie <[email protected]>
Co-authored-by: Heberto Mayorquin <[email protected]>
3 people authored Oct 25, 2024
1 parent aa0c7fe commit 96a28af
Showing 22 changed files with 861 additions and 368 deletions.
26 changes: 26 additions & 0 deletions doc/source/rawio.rst
@@ -281,6 +281,32 @@ Read event timestamps and times
In [42]: print(ev_times)
[ 0.0317]

Signal streams and signal buffers
---------------------------------

For reading analog signals, **neo.rawio** has two important concepts:

1. The **signal_stream** : a group of channels that can be read together using :func:`get_analogsignal_chunk()`.
   This group of channels is guaranteed to have the same sampling rate and the same duration per segment.
   Most of the time, this is a "logical" group of channels, in short channels coming from the same headstage
   or from the same auxiliary board.
   Optionally, depending on the format, a **signal_stream** can be a slice of, or an entire, **signal_buffer**.

2. The **signal_buffer** : a group of channels that share the same data layout in a file. The simplest example
   is a set of channels that can be read by a single :func:`signals = np.memmap(file, shape=..., dtype=..., offset=...)`.
   A **signal_buffer** can contain one or several **signal_streams** (very often only one).
There are two kinds of formats that handle this concept:

* Formats which use :func:`np.memmap()` internally
* Formats based on hdf5

Many formats do not handle this concept:

* those that use an external python package for reading data (edf, ced, plexon2, ...)
* those with a complicated data layout (e.g. those where the data blocks are split without structure)

To check whether a format makes use of the buffer API, check the class attribute flag `has_buffer_description_api` on the
rawio class.
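As a minimal sketch of the idea (outside neo, with a made-up file and keys matching the "raw" descriptions used in this PR), a buffer description carries exactly the information needed to rebuild the memmap generically:

```python
import os
import tempfile

import numpy as np

# A fake 4-channel int16 recording with a 16-byte header, standing in for a real format file.
data = np.arange(40, dtype="int16").reshape(10, 4)
path = os.path.join(tempfile.mkdtemp(), "signals.raw")
with open(path, "wb") as f:
    f.write(b"\x00" * 16)  # fake fixed-size header
    f.write(data.tobytes())

# A "raw"-type buffer description, with the same keys as in this PR.
buffer_description = {
    "type": "raw",
    "file_path": path,
    "dtype": "int16",
    "order": "C",
    "file_offset": 16,
    "shape": (10, 4),
}

# A generic reader needs nothing format-specific beyond this dict:
sigs = np.memmap(
    buffer_description["file_path"],
    dtype=buffer_description["dtype"],
    mode="r",
    offset=buffer_description["file_offset"],
    shape=buffer_description["shape"],
    order=buffer_description["order"],
)
print(sigs[2, :])  # third sample of all 4 channels
```

This is why formats with a compact layout get `get_analogsignal_chunk()` for free, while formats with interleaved or scattered data blocks cannot use the API.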



53 changes: 28 additions & 25 deletions neo/rawio/axonrawio.py
@@ -53,7 +53,7 @@
import numpy as np

from .baserawio import (
BaseRawIO,
BaseRawWithBufferApiIO,
_signal_channel_dtype,
_signal_stream_dtype,
_signal_buffer_dtype,
@@ -63,7 +63,7 @@
from neo.core import NeoReadWriteError


class AxonRawIO(BaseRawIO):
class AxonRawIO(BaseRawWithBufferApiIO):
"""
Class for reading data from pCLAMP and AxoScope files (.abf version 1 and 2)
@@ -92,7 +92,7 @@ class AxonRawIO(BaseRawIO):
rawmode = "one-file"

def __init__(self, filename=""):
BaseRawIO.__init__(self)
BaseRawWithBufferApiIO.__init__(self)
self.filename = filename

def _parse_header(self):
@@ -115,8 +115,6 @@ def _parse_header(self):
head_offset = info["sections"]["DataSection"]["uBlockIndex"] * BLOCKSIZE
totalsize = info["sections"]["DataSection"]["llNumEntries"]

self._raw_data = np.memmap(self.filename, dtype=sig_dtype, mode="r", shape=(totalsize,), offset=head_offset)

# 3 possible modes
if version < 2.0:
mode = info["nOperationMode"]
@@ -142,7 +140,7 @@
)
else:
episode_array = np.empty(1, [("offset", "i4"), ("len", "i4")])
episode_array[0]["len"] = self._raw_data.size
episode_array[0]["len"] = totalsize
episode_array[0]["offset"] = 0

# sampling_rate
@@ -154,9 +152,14 @@
# one sweep = one segment
nb_segment = episode_array.size

stream_id = "0"
buffer_id = "0"

# Get raw data by segment
self._raw_signals = {}
# self._raw_signals = {}
self._t_starts = {}
self._buffer_descriptions = {0 :{}}
self._stream_buffer_slice = {stream_id : None}
pos = 0
for seg_index in range(nb_segment):
length = episode_array[seg_index]["len"]
@@ -169,7 +172,15 @@
if (fSynchTimeUnit != 0) and (mode == 1):
length /= fSynchTimeUnit

self._raw_signals[seg_index] = self._raw_data[pos : pos + length].reshape(-1, nbchannel)
self._buffer_descriptions[0][seg_index] = {}
self._buffer_descriptions[0][seg_index][buffer_id] = {
"type" : "raw",
"file_path" : str(self.filename),
"dtype" : str(sig_dtype),
"order": "C",
"file_offset" : head_offset + pos * sig_dtype.itemsize,
"shape" : (int(length // nbchannel), int(nbchannel)),
}
pos += length

t_start = float(episode_array[seg_index]["offset"])
@@ -227,17 +238,14 @@ def _parse_header(self):
offset -= info["listADCInfo"][chan_id]["fSignalOffset"]
else:
gain, offset = 1.0, 0.0
stream_id = "0"
buffer_id = "0"
signal_channels.append(
(name, str(chan_id), self._sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id)
)

signal_channels.append((name, str(chan_id), self._sampling_rate, sig_dtype, units, gain, offset, stream_id, buffer_id))

signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)

# one unique signal stream and buffer
signal_buffers = np.array([("Signals", "0")], dtype=_signal_buffer_dtype)
signal_streams = np.array([("Signals", "0", "0")], dtype=_signal_stream_dtype)
signal_buffers = np.array([("Signals", buffer_id)], dtype=_signal_buffer_dtype)
signal_streams = np.array([("Signals", stream_id, buffer_id)], dtype=_signal_stream_dtype)

# only one events channel : tag
# In ABF timestamps are not attached to any particular segment
@@ -295,21 +303,16 @@ def _segment_t_start(self, block_index, seg_index):
return self._t_starts[seg_index]

def _segment_t_stop(self, block_index, seg_index):
t_stop = self._t_starts[seg_index] + self._raw_signals[seg_index].shape[0] / self._sampling_rate
sig_size = self.get_signal_size(block_index, seg_index, 0)
t_stop = self._t_starts[seg_index] + sig_size / self._sampling_rate
return t_stop

def _get_signal_size(self, block_index, seg_index, stream_index):
shape = self._raw_signals[seg_index].shape
return shape[0]

def _get_signal_t_start(self, block_index, seg_index, stream_index):
return self._t_starts[seg_index]

def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes):
if channel_indexes is None:
channel_indexes = slice(None)
raw_signals = self._raw_signals[seg_index][slice(i_start, i_stop), channel_indexes]
return raw_signals
def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id):
return self._buffer_descriptions[block_index][seg_index][buffer_id]


def _event_count(self, block_index, seg_index, event_channel_index):
return self._raw_ev_timestamps.size
158 changes: 157 additions & 1 deletion neo/rawio/baserawio.py
@@ -77,6 +77,8 @@

from neo import logging_handler

from .utils import get_memmap_chunk_from_opened_file


possible_raw_modes = [
"one-file",
@@ -182,6 +184,15 @@ def __init__(self, use_cache: bool = False, cache_path: str = "same_as_resource"
self.header = None
self.is_header_parsed = False

self._has_buffer_description_api = False

def has_buffer_description_api(self) -> bool:
"""
Return whether the reader handles the buffer API.
If True, the reader internally supports `get_analogsignal_buffer_description()`.
"""
return self._has_buffer_description_api

def parse_header(self):
"""
Parses the header of the file(s) to allow for faster computations
@@ -191,6 +202,7 @@ def parse_header(self):
# this must create
# self.header['nb_block']
# self.header['nb_segment']
# self.header['signal_buffers']
# self.header['signal_streams']
# self.header['signal_channels']
# self.header['spike_channels']
@@ -663,6 +675,7 @@ def get_signal_size(self, block_index: int, seg_index: int, stream_index: int |
"""
stream_index = self._get_stream_index_from_arg(stream_index)

return self._get_signal_size(block_index, seg_index, stream_index)

def get_signal_t_start(self, block_index: int, seg_index: int, stream_index: int | None = None):
@@ -1311,7 +1324,6 @@ def _get_analogsignal_chunk(
-------
array of samples, with each requested channel in a column
"""

raise (NotImplementedError)

###
@@ -1350,6 +1362,150 @@ def _rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype
def _rescale_epoch_duration(self, raw_duration: np.ndarray, dtype: np.dtype):
raise (NotImplementedError)

###
# buffer api zone
# must be implemented if has_buffer_description_api=True
def get_analogsignal_buffer_description(self, block_index: int = 0, seg_index: int = 0, buffer_id: str = None):
if not self.has_buffer_description_api():
raise ValueError("This reader does not support the buffer_description API")
descr = self._get_analogsignal_buffer_description(block_index, seg_index, buffer_id)
return descr

def _get_analogsignal_buffer_description(self, block_index, seg_index, buffer_id):
raise (NotImplementedError)



class BaseRawWithBufferApiIO(BaseRawIO):
"""
Generic class for readers that support the "buffer api".
In short, readers that are internally based on:
* np.memmap
* hdf5
In these cases _get_signal_size and _get_analogsignal_chunk are totally generic and do not need to be implemented in the subclass.
Subclasses of this class must implement these two dicts:
* self._buffer_descriptions[block_index][seg_index][buffer_id] = buffer_description
* self._stream_buffer_slice[stream_id] = None, or a slice, or indices
"""

def __init__(self, *arg, **kwargs):
super().__init__(*arg, **kwargs)
self._has_buffer_description_api = True

def _get_signal_size(self, block_index, seg_index, stream_index):
buffer_id = self.header["signal_streams"][stream_index]["buffer_id"]
buffer_desc = self.get_analogsignal_buffer_description(block_index, seg_index, buffer_id)
# some hdf5 formats store the buffer with the time axis inverted
time_axis = buffer_desc.get("time_axis", 0)
return buffer_desc['shape'][time_axis]
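Independently of neo, the subclass contract can be illustrated with a toy pair of classes (names here are illustrative, not neo's API): the base class implements the generic logic from the description dicts alone, and a format-specific subclass only fills in those dicts.

```python
class ToyBufferReaderBase:
    # Generic part: knows nothing about the format,
    # only reads the description dicts the subclass filled in.
    def get_signal_size(self, block_index, seg_index, buffer_id):
        desc = self._buffer_descriptions[block_index][seg_index][buffer_id]
        time_axis = desc.get("time_axis", 0)  # some hdf5 layouts put time on axis 1
        return desc["shape"][time_axis]


class ToyRawFormatReader(ToyBufferReaderBase):
    # Format-specific part: its only job is to describe the buffers.
    def __init__(self, shape):
        self._buffer_descriptions = {
            0: {0: {"0": {"type": "raw", "shape": shape}}}
        }
        # stream "0" covers the whole buffer, so no channel slice is needed
        self._stream_buffer_slice = {"0": None}


reader = ToyRawFormatReader(shape=(5000, 16))
print(reader.get_signal_size(0, 0, "0"))
```

The parsing code per format shrinks to building dicts, and chunk reading lives in one audited place.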

def _get_analogsignal_chunk(
self,
block_index: int,
seg_index: int,
i_start: int | None,
i_stop: int | None,
stream_index: int,
channel_indexes: list[int] | None,
):

stream_id = self.header["signal_streams"][stream_index]["id"]
buffer_id = self.header["signal_streams"][stream_index]["buffer_id"]

buffer_slice = self._stream_buffer_slice[stream_id]


buffer_desc = self.get_analogsignal_buffer_description(block_index, seg_index, buffer_id)

i_start = i_start or 0
i_stop = i_stop or buffer_desc['shape'][0]

if buffer_desc['type'] == "raw":

# open files on demand and keep reference to opened file
if not hasattr(self, '_memmap_analogsignal_buffers'):
self._memmap_analogsignal_buffers = {}
if block_index not in self._memmap_analogsignal_buffers:
self._memmap_analogsignal_buffers[block_index] = {}
if seg_index not in self._memmap_analogsignal_buffers[block_index]:
self._memmap_analogsignal_buffers[block_index][seg_index] = {}
if buffer_id not in self._memmap_analogsignal_buffers[block_index][seg_index]:
fid = open(buffer_desc['file_path'], mode='rb')
self._memmap_analogsignal_buffers[block_index][seg_index][buffer_id] = fid
else:
fid = self._memmap_analogsignal_buffers[block_index][seg_index][buffer_id]

num_channels = buffer_desc['shape'][1]

raw_sigs = get_memmap_chunk_from_opened_file(fid, num_channels, i_start, i_stop, np.dtype(buffer_desc['dtype']), file_offset=buffer_desc['file_offset'])


elif buffer_desc['type'] == 'hdf5':

# open files on demand and keep reference to opened file
if not hasattr(self, '_hdf5_analogsignal_buffers'):
self._hdf5_analogsignal_buffers = {}
if block_index not in self._hdf5_analogsignal_buffers:
self._hdf5_analogsignal_buffers[block_index] = {}
if seg_index not in self._hdf5_analogsignal_buffers[block_index]:
self._hdf5_analogsignal_buffers[block_index][seg_index] = {}
if buffer_id not in self._hdf5_analogsignal_buffers[block_index][seg_index]:
import h5py
h5file = h5py.File(buffer_desc['file_path'], mode="r")
self._hdf5_analogsignal_buffers[block_index][seg_index][buffer_id] = h5file
else:
h5file = self._hdf5_analogsignal_buffers[block_index][seg_index][buffer_id]

hdf5_path = buffer_desc["hdf5_path"]
full_raw_sigs = h5file[hdf5_path]

time_axis = buffer_desc.get("time_axis", 0)
if time_axis == 0:
raw_sigs = full_raw_sigs[i_start:i_stop, :]
elif time_axis == 1:
raw_sigs = full_raw_sigs[:, i_start:i_stop].T
else:
raise RuntimeError("Should never happen")




else:
raise NotImplementedError()

# pre-slice when the stream does not contain all channels of the buffer (for instance spikeglx when load_sync_channel=False)
if buffer_slice is not None:
raw_sigs = raw_sigs[:, buffer_slice]

# channel slice requested
if channel_indexes is not None:
raw_sigs = raw_sigs[:, channel_indexes]


return raw_sigs
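The "raw" branch above delegates to `get_memmap_chunk_from_opened_file` from `neo.rawio.utils`. A rough standalone equivalent is sketched below; it is an assumption about the idea (seek to the window, read only `i_stop - i_start` samples), not the actual utility, and it returns a copied array where the real helper may return a memmap view.

```python
import numpy as np


def memmap_chunk_sketch(fid, num_channels, i_start, i_stop, dtype, file_offset=0):
    """Read samples i_start:i_stop (all channels) from an already-opened binary file."""
    itemsize = dtype.itemsize
    start_byte = file_offset + i_start * num_channels * itemsize
    num_bytes = (i_stop - i_start) * num_channels * itemsize
    fid.seek(start_byte)
    raw = fid.read(num_bytes)
    return np.frombuffer(raw, dtype=dtype).reshape(i_stop - i_start, num_channels)


# usage: 3 interleaved int16 channels, no header
import os
import tempfile

data = np.arange(30, dtype="int16").reshape(10, 3)
path = os.path.join(tempfile.mkdtemp(), "buf.raw")
with open(path, "wb") as f:
    f.write(data.tobytes())

fid = open(path, "rb")
chunk = memmap_chunk_sketch(fid, 3, 2, 5, np.dtype("int16"))
print(chunk[0])  # first row of the requested window
fid.close()
```

Keeping the file handle open and cached per (block, segment, buffer), as the class above does, is what avoids re-creating a memmap per call and hence the memory-leak problem mentioned in the commit message.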

def __del__(self):
if hasattr(self, '_memmap_analogsignal_buffers'):
for block_index in self._memmap_analogsignal_buffers.keys():
for seg_index in self._memmap_analogsignal_buffers[block_index].keys():
for buffer_id, fid in self._memmap_analogsignal_buffers[block_index][seg_index].items():
fid.close()
del self._memmap_analogsignal_buffers

if hasattr(self, '_hdf5_analogsignal_buffers'):
for block_index in self._hdf5_analogsignal_buffers.keys():
for seg_index in self._hdf5_analogsignal_buffers[block_index].keys():
for buffer_id, h5_file in self._hdf5_analogsignal_buffers[block_index][seg_index].items():
h5_file.close()
del self._hdf5_analogsignal_buffers


def pprint_vector(vector, lim: int = 8):
vector = np.asarray(vector)
13 changes: 8 additions & 5 deletions neo/rawio/bci2000rawio.py
@@ -1,6 +1,9 @@
"""
BCI2000RawIO is a class to read BCI2000 .dat files.
https://www.bci2000.org/mediawiki/index.php/Technical_Reference:BCI2000_File_Format
Note: BCI2000RawIO cannot be implemented using has_buffer_description_api because the buffer
of signals is not compact (it has some interleaved state uint values between channels).
"""

import numpy as np
@@ -50,9 +53,11 @@ def _parse_header(self):
self.header["nb_block"] = 1
self.header["nb_segment"] = [1]

# one unique stream and buffer
signal_buffers = np.array(("Signals", "0"), dtype=_signal_buffer_dtype)
signal_streams = np.array([("Signals", "0", "0")], dtype=_signal_stream_dtype)
# one unique stream but no buffer because channels are not compact
stream_id = "0"
buffer_id = ""
signal_buffers = np.array([], dtype=_signal_buffer_dtype)
signal_streams = np.array([("Signals", stream_id, buffer_id)], dtype=_signal_stream_dtype)
self.header["signal_buffers"] = signal_buffers
self.header["signal_streams"] = signal_streams

@@ -80,8 +85,6 @@ def _parse_header(self):
if isinstance(offset, str):
offset = float(offset)

stream_id = "0"
buffer_id = "0"
sig_channels.append((ch_name, chan_id, sr, dtype, units, gain, offset, stream_id, buffer_id))
self.header["signal_channels"] = np.array(sig_channels, dtype=_signal_channel_dtype)
