Skip to content

Commit 182a7e9

Browse files
authored
Merge pull request #1307 from kjohnsen/segment-add
add `add()` function to `Container` for use on `Segment` and `Block`
2 parents 6ce00dc + 0c7bb3a commit 182a7e9

File tree

4 files changed

+73
-12
lines changed

4 files changed

+73
-12
lines changed

neo/core/container.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,25 @@ def size(self):
327327
return {name: len(getattr(self, name))
328328
for name in self._child_containers}
329329

330+
@property
331+
def _container_lookup(self):
332+
return {
333+
cls_name: getattr(self, container_name)
334+
for cls_name, container_name in zip(self._child_objects, self._child_containers)
335+
}
336+
337+
def _get_container(self, cls):
338+
if hasattr(cls, "proxy_for"):
339+
cls = cls.proxy_for
340+
return self._container_lookup[cls.__name__]
341+
342+
def add(self, *objects):
343+
"""Add a new Neo object to the Container"""
344+
for obj in objects:
345+
container = self._get_container(obj.__class__)
346+
container.append(obj)
347+
348+
330349
def filter(self, targdict=None, data=True, container=False, recursive=True,
331350
objects=None, **kwargs):
332351
"""

neo/core/group.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,18 +134,6 @@ def __init__(self, objects=None, name=None, description=None, file_origin=None,
134134
doc="list of Groups contained in this group"
135135
)
136136

137-
@property
138-
def _container_lookup(self):
139-
return {
140-
cls_name: getattr(self, container_name)
141-
for cls_name, container_name in zip(self._child_objects, self._child_containers)
142-
}
143-
144-
def _get_container(self, cls):
145-
if hasattr(cls, "proxy_for"):
146-
cls = cls.proxy_for
147-
return self._container_lookup[cls.__name__]
148-
149137
def add(self, *objects):
150138
"""Add a new Neo object to the Group"""
151139
for obj in objects:

neo/test/coretest/test_block.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,33 @@ def test_segment_list(self):
466466
blk.segments += [Segment(), Segment()]
467467
assert len(blk.segments) == 3
468468

469+
def test_add(self):
470+
blk = self.blocks[0]
471+
new_blk = simple_block()
472+
n_groups_start = len(new_blk.groups)
473+
for group in blk.groups:
474+
assert group not in new_blk.groups
475+
new_blk.add(group)
476+
assert group in new_blk.groups
477+
assert len(new_blk.groups) == n_groups_start + len(blk.groups)
478+
479+
n_segs_start = len(new_blk.segments)
480+
for seg in blk.segments:
481+
assert seg not in new_blk.segments
482+
new_blk.add(seg)
483+
assert seg in new_blk.segments
484+
assert len(new_blk.segments) == n_segs_start + len(blk.segments)
485+
486+
# test adding multiple at once
487+
blk = self.blocks[1]
488+
n_groups_start = len(new_blk.groups)
489+
new_blk.add(*blk.groups)
490+
assert len(new_blk.groups) == n_groups_start + len(blk.groups)
491+
492+
n_segs_start = len(new_blk.segments)
493+
new_blk.add(*blk.segments)
494+
assert len(new_blk.segments) == n_segs_start + len(blk.segments)
495+
469496

470497
if __name__ == "__main__":
471498
unittest.main()

neo/test/coretest/test_segment.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,33 @@ def test__deepcopy(self):
642642
for child in getattr(seg1_copy, childtype, []):
643643
self.assertEqual(id(child.segment), id(seg1_copy))
644644

645+
def test_add(self):
646+
seg = Segment()
647+
648+
reader = ExampleRawIO(filename='my_filename.fake')
649+
reader.parse_header()
650+
651+
proxy_anasig = AnalogSignalProxy(rawio=reader,
652+
stream_index=0, inner_stream_channels=None,
653+
block_index=0, seg_index=0)
654+
seg.add(proxy_anasig)
655+
assert len(seg.analogsignals) == 1
656+
657+
proxy_st = SpikeTrainProxy(rawio=reader, spike_channel_index=0,
658+
block_index=0, seg_index=0)
659+
seg.add(proxy_st)
660+
assert len(seg.spiketrains) == 1
661+
662+
proxy_event = EventProxy(rawio=reader, event_channel_index=0,
663+
block_index=0, seg_index=0)
664+
seg.add(proxy_event)
665+
assert len(seg.events) == 1
666+
667+
proxy_epoch = EpochProxy(rawio=reader, event_channel_index=1,
668+
block_index=0, seg_index=0)
669+
seg.add(proxy_epoch)
670+
assert len(seg.epochs) == 1
671+
645672

646673
if __name__ == "__main__":
647674
unittest.main()

0 commit comments

Comments
 (0)