Skip to content

Commit cf4c7fb

Browse files
committed
added strict monotonicity flag for hierarchical segmentation metrics
1 parent a0a8672 commit cf4c7fb

File tree

3 files changed

+201
-24
lines changed

3 files changed

+201
-24
lines changed

.gitignore

+3-1
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ Thumbs.db
3838
# Vim
3939
*.swp
4040

41-
# pycharm
41+
# IDEs
4242
.idea/*
43+
.vscode/*
44+
.vscode/*
4345

4446
# docs
4547
docs/_build/*

mir_eval/hierarchy.py

+91-17
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def _align_intervals(int_hier, lab_hier, t_min=0.0, t_max=None):
132132
]
133133

134134

135-
def _lca(intervals_hier, frame_size):
135+
def _lca(intervals_hier, frame_size, strict_mono=False):
136136
"""Compute the (sparse) least-common-ancestor (LCA) matrix for a
137137
hierarchical segmentation.
138138
@@ -147,6 +147,10 @@ def _lca(intervals_hier, frame_size):
147147
The list is assumed to be ordered by increasing specificity (depth).
148148
frame_size : number
149149
The length of the sample frames (in seconds)
150+
strict_mono : bool, optional
151+
If True, enforce monotonic updates for the LCA matrix. Only positions that were set to
152+
the previous level (i.e., equal to level - 1) will be updated to the current level.
153+
If False, the current level is applied unconditionally. Default is False.
150154
151155
Returns
152156
-------
@@ -170,18 +174,25 @@ def _lca(intervals_hier, frame_size):
170174
int
171175
):
172176
idx = slice(ival[0], ival[1])
173-
lca_matrix[idx, idx] = level
177+
if level == 1 or not strict_mono:
178+
lca_matrix[idx, idx] = level
179+
else:
180+
# Find frame pairs that already share a segment at the previous level
181+
current_meet = lca_matrix[idx, idx].toarray()
182+
matching_parents_mask = current_meet == level - 1
183+
# Update only at positions where the previous level also matches
184+
current_meet[matching_parents_mask] = level
185+
lca_matrix[idx, idx] = current_meet
174186

175187
return lca_matrix.tocsr()
176188

177189

178-
def _meet(intervals_hier, labels_hier, frame_size):
179-
"""Compute the (sparse) least-common-ancestor (LCA) matrix for a
190+
def _meet(intervals_hier, labels_hier, frame_size, strict_mono=False):
191+
"""Compute the (sparse) annotation meet matrix for a
180192
hierarchical segmentation.
181193
182-
For any pair of frames ``(s, t)``, the LCA is the deepest level in
183-
the hierarchy such that ``(s, t)`` are contained within a single
184-
segment at that level.
194+
For any pair of frames ``(s, t)``, the annotation meet matrix is the deepest level
195+
in the hierarchy such that ``(s, t)`` receive the same segment label, i.e. they meet.
185196
186197
Parameters
187198
----------
@@ -193,6 +204,10 @@ def _meet(intervals_hier, labels_hier, frame_size):
193204
``i``th layer of the annotations
194205
frame_size : number
195206
The length of the sample frames (in seconds)
207+
strict_mono : bool, optional
208+
If True, enforce monotonic updates for the meet matrix. Only positions that were set to
209+
the previous level (i.e., equal to level - 1) will be updated to the current level.
210+
If False, the current level is applied unconditionally. Default is False.
196211
197212
Returns
198213
-------
@@ -225,9 +240,21 @@ def _meet(intervals_hier, labels_hier, frame_size):
225240
for seg_i, seg_j in zip(*np.where(int_agree)):
226241
idx_i = slice(*list(int_frames[seg_i]))
227242
idx_j = slice(*list(int_frames[seg_j]))
228-
meet_matrix[idx_i, idx_j] = level
229-
if seg_i != seg_j:
230-
meet_matrix[idx_j, idx_i] = level
243+
244+
if level == 1 or not strict_mono:
245+
meet_matrix[idx_i, idx_j] = level
246+
if seg_i != seg_j:
247+
meet_matrix[idx_j, idx_i] = level
248+
249+
else:
250+
# Extract current submatrix and update elementwise
251+
current_meet = meet_matrix[idx_i, idx_j].toarray()
252+
mask = current_meet == (level - 1)
253+
current_meet[mask] = level
254+
meet_matrix[idx_i, idx_j] = current_meet
255+
256+
if seg_i != seg_j:
257+
meet_matrix[idx_j, idx_i] = current_meet
231258

232259
return scipy.sparse.csr_matrix(meet_matrix)
233260

@@ -446,21 +473,56 @@ def validate_hier_intervals(intervals_hier):
446473
# Synthesize a label array for the top layer.
447474
label_top = util.generate_labels(intervals_hier[0])
448475

449-
boundaries = set(util.intervals_to_boundaries(intervals_hier[0]))
450-
451-
for level, intervals in enumerate(intervals_hier[1:], 1):
476+
for intervals in intervals_hier[1:]:
452477
# Make sure this level is consistent with the root
453478
label_current = util.generate_labels(intervals)
454479
validate_structure(intervals_hier[0], label_top, intervals, label_current)
455480

481+
check_monotonic_boundaries(intervals_hier)
482+
483+
484+
def check_monotonic_boundaries(intervals_hier):
485+
"""Check that a hierarchical annotation has monotonic boundaries.
486+
487+
Parameters
488+
----------
489+
intervals_hier : ordered list of segmentations
490+
491+
Returns
492+
-------
493+
bool
494+
True if the annotation has monotonic boundaries, False otherwise
495+
"""
496+
result = True
497+
boundaries = set(util.intervals_to_boundaries(intervals_hier[0]))
498+
499+
for level, intervals in enumerate(intervals_hier[1:], 1):
456500
# Make sure all previous boundaries are accounted for
457501
new_bounds = set(util.intervals_to_boundaries(intervals))
458502

459503
if boundaries - new_bounds:
460504
warnings.warn(
461505
"Segment hierarchy is inconsistent " "at level {:d}".format(level)
462506
)
507+
result = False
463508
boundaries |= new_bounds
509+
return result
510+
511+
512+
def check_monotonic_labels(intervals_hier):
513+
"""Check that a hierarchical annotation has monotonic labels.
514+
515+
Parameters
516+
----------
517+
intervals_hier : ordered list of segmentations
518+
519+
Returns
520+
-------
521+
bool
522+
True if the annotation has monotonic labels, False otherwise
523+
"""
524+
# TODO: Check whether the strictly monotonic annotation meet matrix and the maximum-depth meet matrix are the same.
525+
return True
464526

465527

466528
def tmeasure(
@@ -470,6 +532,7 @@ def tmeasure(
470532
window=15.0,
471533
frame_size=0.1,
472534
beta=1.0,
535+
strict_mono=False,
473536
):
474537
"""Compute the tree measures for hierarchical segment annotations.
475538
@@ -533,8 +596,8 @@ def tmeasure(
533596
validate_hier_intervals(estimated_intervals_hier)
534597

535598
# Build the least common ancestor matrices
536-
ref_lca = _lca(reference_intervals_hier, frame_size)
537-
est_lca = _lca(estimated_intervals_hier, frame_size)
599+
ref_lca = _lca(reference_intervals_hier, frame_size, strict_mono=strict_mono)
600+
est_lca = _lca(estimated_intervals_hier, frame_size, strict_mono=strict_mono)
538601

539602
# Compute precision and recall
540603
t_recall = _gauc(ref_lca, est_lca, transitive, window_frames)
@@ -552,6 +615,7 @@ def lmeasure(
552615
estimated_labels_hier,
553616
frame_size=0.1,
554617
beta=1.0,
618+
strict_mono=False,
555619
):
556620
"""Compute the tree measures for hierarchical segment annotations.
557621
@@ -604,8 +668,18 @@ def lmeasure(
604668
validate_hier_intervals(estimated_intervals_hier)
605669

606670
# Build the annotation meet matrices
607-
ref_meet = _meet(reference_intervals_hier, reference_labels_hier, frame_size)
608-
est_meet = _meet(estimated_intervals_hier, estimated_labels_hier, frame_size)
671+
ref_meet = _meet(
672+
reference_intervals_hier,
673+
reference_labels_hier,
674+
frame_size,
675+
strict_mono=strict_mono,
676+
)
677+
est_meet = _meet(
678+
estimated_intervals_hier,
679+
estimated_labels_hier,
680+
frame_size,
681+
strict_mono=strict_mono,
682+
)
609683

610684
# Compute precision and recall
611685
l_recall = _gauc(ref_meet, est_meet, True, None)

tests/test_hierarchy.py

+107-6
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
@pytest.mark.parametrize("window", [5, 10, 15, 30, 90, None])
2020
@pytest.mark.parametrize("frame_size", [0.1, 0.5, 1.0])
21-
def test_tmeasure_pass(window, frame_size):
21+
@pytest.mark.parametrize("strict_mono", [True, False])
22+
def test_tmeasure_pass(window, frame_size, strict_mono):
2223
# The estimate here gets none of the structure correct.
2324
ref = [[[0, 30]], [[0, 15], [15, 30]]]
2425
# convert to arrays
@@ -27,13 +28,17 @@ def test_tmeasure_pass(window, frame_size):
2728
est = ref[:1]
2829

2930
# The estimate should get 0 score here
30-
scores = mir_eval.hierarchy.tmeasure(ref, est, window=window, frame_size=frame_size)
31+
scores = mir_eval.hierarchy.tmeasure(
32+
ref, est, window=window, frame_size=frame_size, strict_mono=strict_mono
33+
)
3134

3235
for k in scores:
3336
assert k == 0.0
3437

3538
# The reference should get a perfect score here
36-
scores = mir_eval.hierarchy.tmeasure(ref, ref, window=window, frame_size=frame_size)
39+
scores = mir_eval.hierarchy.tmeasure(
40+
ref, ref, window=window, frame_size=frame_size, strict_mono=strict_mono
41+
)
3742

3843
for k in scores:
3944
assert k == 1.0
@@ -91,7 +96,8 @@ def test_tmeasure_fail_frame_size(window, frame_size):
9196

9297

9398
@pytest.mark.parametrize("frame_size", [0.1, 0.5, 1.0])
94-
def test_lmeasure_pass(frame_size):
99+
@pytest.mark.parametrize("strict_mono", [True, False])
100+
def test_lmeasure_pass(frame_size, strict_mono):
95101
# The estimate here gets none of the structure correct.
96102
ref = [[[0, 30]], [[0, 15], [15, 30]]]
97103
ref_lab = [["A"], ["a", "b"]]
@@ -104,15 +110,15 @@ def test_lmeasure_pass(frame_size):
104110

105111
# The estimate should get 0 score here
106112
scores = mir_eval.hierarchy.lmeasure(
107-
ref, ref_lab, est, est_lab, frame_size=frame_size
113+
ref, ref_lab, est, est_lab, frame_size=frame_size, strict_mono=strict_mono
108114
)
109115

110116
for k in scores:
111117
assert k == 0.0
112118

113119
# The reference should get a perfect score here
114120
scores = mir_eval.hierarchy.lmeasure(
115-
ref, ref_lab, ref, ref_lab, frame_size=frame_size
121+
ref, ref_lab, ref, ref_lab, frame_size=frame_size, strict_mono=strict_mono
116122
)
117123

118124
for k in scores:
@@ -286,6 +292,101 @@ def test_meet():
286292
assert np.all(meet == meet_truth)
287293

288294

295+
def test_strict_mono():
296+
frame_size = 1
297+
int_hier = [
298+
np.array([[0, 10]]),
299+
np.array([[0, 6], [6, 10]]),
300+
np.array([[0, 2], [2, 4], [4, 8], [8, 10]]),
301+
]
302+
303+
lab_hier = [["X"], ["A", "B"], ["a", "b", "c", "b"]]
304+
305+
# Target output
306+
meet_truth = np.asarray(
307+
[
308+
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
309+
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
310+
[2, 2, 3, 3, 2, 2, 1, 1, 3, 3], # (XAb)
311+
[2, 2, 3, 3, 2, 2, 1, 1, 3, 3], # (XAb)
312+
[2, 2, 2, 2, 3, 3, 3, 3, 1, 1], # (XAc)
313+
[2, 2, 2, 2, 3, 3, 3, 3, 1, 1], # (XAc)
314+
[1, 1, 1, 1, 3, 3, 3, 3, 2, 2], # (XBc)
315+
[1, 1, 1, 1, 3, 3, 3, 3, 2, 2], # (XBc)
316+
[1, 1, 3, 3, 1, 1, 2, 2, 3, 3], # (XBb)
317+
[1, 1, 3, 3, 1, 1, 2, 2, 3, 3], # (XBb)
318+
]
319+
)
320+
meet_truth_strict_mono = np.asarray(
321+
[
322+
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
323+
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
324+
[2, 2, 3, 3, 2, 2, 1, 1, 1, 1], # (XAb)
325+
[2, 2, 3, 3, 2, 2, 1, 1, 1, 1], # (XAb)
326+
[2, 2, 2, 2, 3, 3, 1, 1, 1, 1], # (XAc)
327+
[2, 2, 2, 2, 3, 3, 1, 1, 1, 1], # (XAc)
328+
[1, 1, 1, 1, 1, 1, 3, 3, 2, 2], # (XBc)
329+
[1, 1, 1, 1, 1, 1, 3, 3, 2, 2], # (XBc)
330+
[1, 1, 1, 1, 1, 1, 2, 2, 3, 3], # (XBb)
331+
[1, 1, 1, 1, 1, 1, 2, 2, 3, 3], # (XBb)
332+
]
333+
)
334+
lca_truth = np.asarray(
335+
[
336+
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
337+
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
338+
[2, 2, 3, 3, 2, 2, 1, 1, 1, 1], # (XAb)
339+
[2, 2, 3, 3, 2, 2, 1, 1, 1, 1], # (XAb)
340+
[2, 2, 2, 2, 3, 3, 3, 3, 1, 1], # (XAc)
341+
[2, 2, 2, 2, 3, 3, 3, 3, 1, 1], # (XAc)
342+
[1, 1, 1, 1, 3, 3, 3, 3, 2, 2], # (XBc)
343+
[1, 1, 1, 1, 3, 3, 3, 3, 2, 2], # (XBc)
344+
[1, 1, 1, 1, 1, 1, 2, 2, 3, 3], # (XBd)
345+
[1, 1, 1, 1, 1, 1, 2, 2, 3, 3], # (XBd)
346+
]
347+
)
348+
lca_truth_strict_mono = np.asarray(
349+
[
350+
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
351+
[3, 3, 2, 2, 2, 2, 1, 1, 1, 1], # (XAa)
352+
[2, 2, 3, 3, 2, 2, 1, 1, 1, 1], # (XAb)
353+
[2, 2, 3, 3, 2, 2, 1, 1, 1, 1], # (XAb)
354+
[2, 2, 2, 2, 3, 3, 1, 1, 1, 1], # (XAc)
355+
[2, 2, 2, 2, 3, 3, 1, 1, 1, 1], # (XAc)
356+
[1, 1, 1, 1, 1, 1, 3, 3, 2, 2], # (XBc)
357+
[1, 1, 1, 1, 1, 1, 3, 3, 2, 2], # (XBc)
358+
[1, 1, 1, 1, 1, 1, 2, 2, 3, 3], # (XBd)
359+
[1, 1, 1, 1, 1, 1, 2, 2, 3, 3], # (XBd)
360+
]
361+
)
362+
363+
meet = mir_eval.hierarchy._meet(int_hier, lab_hier, frame_size, strict_mono=False)
364+
meet_strict_mono = mir_eval.hierarchy._meet(
365+
int_hier, lab_hier, frame_size, strict_mono=True
366+
)
367+
lca = mir_eval.hierarchy._lca(int_hier, frame_size, strict_mono=False)
368+
lca_strict_mono = mir_eval.hierarchy._lca(int_hier, frame_size, strict_mono=True)
369+
# Is it the right type?
370+
assert isinstance(meet, scipy.sparse.csr_matrix)
371+
meet = meet.toarray()
372+
meet_strict_mono = meet_strict_mono.toarray()
373+
assert isinstance(lca_strict_mono, scipy.sparse.csr_matrix)
374+
lca = lca.toarray()
375+
lca_strict_mono = lca_strict_mono.toarray()
376+
377+
# Does it have the right shape?
378+
assert meet.shape == (10, 10)
379+
assert meet_strict_mono.shape == (10, 10)
380+
assert lca.shape == (10, 10)
381+
assert lca_strict_mono.shape == (10, 10)
382+
383+
# Does it have the right value?
384+
assert np.all(meet == meet_truth)
385+
assert np.all(meet_strict_mono == meet_truth_strict_mono)
386+
assert np.all(lca == lca_truth)
387+
assert np.all(lca_strict_mono == lca_truth_strict_mono)
388+
389+
289390
def test_compare_frame_rankings():
290391
# number of pairs (i, j)
291392
# where ref[i] < ref[j] and est[i] >= est[j]

0 commit comments

Comments
 (0)