Skip to content

Commit 416779e

Browse files
committed
ci: cache reference metrics & clean audio tests (#2335)
* cache reference metrics * audio * classif * regress * image * others * cleaning --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> (cherry picked from commit c53ea94)
1 parent 4899405 commit 416779e

File tree

103 files changed

+771
-730
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

103 files changed

+771
-730
lines changed

.gitignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,8 @@ pip-delete-this-directory.txt
4040
# Unit test / coverage reports
4141
tests/_data/
4242
data.zip
43+
tests/_reference-cache/
4344
htmlcov/
44-
.tox/
45-
.nox/
4645
.coverage
4746
.coverage.*
4847
.cache

tests/unittests/audio/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,30 @@
11
import os
2+
from typing import Callable, Optional
3+
4+
from torch import Tensor
25

36
from unittests import _PATH_ALL_TESTS
47

58
_SAMPLE_AUDIO_SPEECH = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "audio_speech.wav")
69
_SAMPLE_AUDIO_SPEECH_BAB_DB = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "audio_speech_bab_0dB.wav")
710
_SAMPLE_NUMPY_ISSUE_895 = os.path.join(_PATH_ALL_TESTS, "_data", "audio", "issue_895.npz")
11+
12+
13+
def _average_metric_wrapper(
14+
preds: Tensor, target: Tensor, metric_func: Callable, res_index: Optional[int] = None
15+
) -> Tensor:
16+
"""Average the metric values.
17+
18+
Args:
19+
preds: predictions, shape[batch, spk, time]
20+
target: targets, shape[batch, spk, time]
21+
metric_func: a function which returns the metric values, either as a tensor or as a tuple such as (best_metric, best_perm)
22+
res_index: if not None, index of the element of the metric_func output to average (e.g. 0 selects best_metric)
23+
24+
Returns:
25+
the average of the metric values (of best_metric when res_index selects it)
26+
27+
"""
28+
if res_index is not None:
29+
return metric_func(preds, target)[res_index].mean()
30+
return metric_func(preds, target).mean()

tests/unittests/audio/test_pesq.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from torchmetrics.functional.audio import perceptual_evaluation_speech_quality
2323

2424
from unittests import _Input
25-
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB
25+
from unittests.audio import _SAMPLE_AUDIO_SPEECH, _SAMPLE_AUDIO_SPEECH_BAB_DB, _average_metric_wrapper
2626
from unittests.helpers import seed_all
2727
from unittests.helpers.testers import MetricTester
2828

@@ -41,7 +41,7 @@
4141
)
4242

4343

44-
def _pesq_original_batch(preds: Tensor, target: Tensor, fs: int, mode: str):
44+
def _reference_pesq_batch(preds: Tensor, target: Tensor, fs: int, mode: str):
4545
"""Comparison function."""
4646
# shape: preds [BATCH_SIZE, Time] , target [BATCH_SIZE, Time]
4747
# or shape: preds [NUM_BATCHES*BATCH_SIZE, Time] , target [NUM_BATCHES*BATCH_SIZE, Time]
@@ -54,23 +54,12 @@ def _pesq_original_batch(preds: Tensor, target: Tensor, fs: int, mode: str):
5454
return torch.tensor(mss)
5555

5656

57-
def _average_metric(preds, target, metric_func):
58-
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
59-
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
60-
return metric_func(preds, target).mean()
61-
62-
63-
pesq_original_batch_8k_nb = partial(_pesq_original_batch, fs=8000, mode="nb")
64-
pesq_original_batch_16k_nb = partial(_pesq_original_batch, fs=16000, mode="nb")
65-
pesq_original_batch_16k_wb = partial(_pesq_original_batch, fs=16000, mode="wb")
66-
67-
6857
@pytest.mark.parametrize(
6958
"preds, target, ref_metric, fs, mode",
7059
[
71-
(inputs_8k.preds, inputs_8k.target, pesq_original_batch_8k_nb, 8000, "nb"),
72-
(inputs_16k.preds, inputs_16k.target, pesq_original_batch_16k_nb, 16000, "nb"),
73-
(inputs_16k.preds, inputs_16k.target, pesq_original_batch_16k_wb, 16000, "wb"),
60+
(inputs_8k.preds, inputs_8k.target, partial(_reference_pesq_batch, fs=8000, mode="nb"), 8000, "nb"),
61+
(inputs_16k.preds, inputs_16k.target, partial(_reference_pesq_batch, fs=16000, mode="nb"), 16000, "nb"),
62+
(inputs_16k.preds, inputs_16k.target, partial(_reference_pesq_batch, fs=16000, mode="wb"), 16000, "wb"),
7463
],
7564
)
7665
class TestPESQ(MetricTester):
@@ -89,7 +78,7 @@ def test_pesq(self, preds, target, ref_metric, fs, mode, num_processes, ddp):
8978
preds,
9079
target,
9180
PerceptualEvaluationSpeechQuality,
92-
reference_metric=partial(_average_metric, metric_func=ref_metric),
81+
reference_metric=partial(_average_metric_wrapper, metric_func=ref_metric),
9382
metric_args={"fs": fs, "mode": mode, "n_processes": num_processes},
9483
)
9584

tests/unittests/audio/test_pit.py

Lines changed: 32 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -31,27 +31,28 @@
3131
)
3232

3333
from unittests import BATCH_SIZE, NUM_BATCHES, _Input
34+
from unittests.audio import _average_metric_wrapper
3435
from unittests.helpers import seed_all
3536
from unittests.helpers.testers import MetricTester
3637

3738
seed_all(42)
3839

39-
TIME = 10
40+
TIME_FRAME = 10
4041

4142

4243
# three speaker examples to test _find_best_perm_by_linear_sum_assignment
4344
inputs1 = _Input(
44-
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME),
45-
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME),
45+
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME_FRAME),
46+
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 3, TIME_FRAME),
4647
)
4748
# two speaker examples to test _find_best_perm_by_exhaustive_method
4849
inputs2 = _Input(
49-
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME),
50-
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME),
50+
preds=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME_FRAME),
51+
target=torch.rand(NUM_BATCHES, BATCH_SIZE, 2, TIME_FRAME),
5152
)
5253

5354

54-
def naive_implementation_pit_scipy(
55+
def _reference_scipy_pit(
5556
preds: Tensor,
5657
target: Tensor,
5758
metric_func: Callable,
@@ -66,10 +67,8 @@ def naive_implementation_pit_scipy(
6667
eval_func: min or max
6768
6869
Returns:
69-
best_metric:
70-
shape [batch]
71-
best_perm:
72-
shape [batch, spk]
70+
best_metric: shape [batch]
71+
best_perm: shape [batch, spk]
7372
7473
"""
7574
batch_size, spk_num = target.shape[0:2]
@@ -88,62 +87,59 @@ def naive_implementation_pit_scipy(
8887
return torch.from_numpy(np.stack(best_metrics)), torch.from_numpy(np.stack(best_perms))
8988

9089

91-
def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor:
92-
"""Average the metric values.
90+
def _reference_scipy_pit_snr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
91+
return _reference_scipy_pit(
92+
preds=preds,
93+
target=target,
94+
metric_func=signal_noise_ratio,
95+
eval_func="max",
96+
)
9397

94-
Args:
95-
preds: predictions, shape[batch, spk, time]
96-
target: targets, shape[batch, spk, time]
97-
metric_func: a function which return best_metric and best_perm
98-
99-
Returns:
100-
the average of best_metric
10198

102-
"""
103-
return metric_func(preds, target)[0].mean()
104-
105-
106-
snr_pit_scipy = partial(naive_implementation_pit_scipy, metric_func=signal_noise_ratio, eval_func="max")
107-
si_sdr_pit_scipy = partial(
108-
naive_implementation_pit_scipy, metric_func=scale_invariant_signal_distortion_ratio, eval_func="max"
109-
)
99+
def _reference_scipy_pit_si_sdr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
100+
return _reference_scipy_pit(
101+
preds=preds,
102+
target=target,
103+
metric_func=scale_invariant_signal_distortion_ratio,
104+
eval_func="max",
105+
)
110106

111107

112108
@pytest.mark.parametrize(
113109
"preds, target, ref_metric, metric_func, mode, eval_func",
114110
[
115-
(inputs1.preds, inputs1.target, snr_pit_scipy, signal_noise_ratio, "speaker-wise", "max"),
111+
(inputs1.preds, inputs1.target, _reference_scipy_pit_snr, signal_noise_ratio, "speaker-wise", "max"),
116112
(
117113
inputs1.preds,
118114
inputs1.target,
119-
si_sdr_pit_scipy,
115+
_reference_scipy_pit_si_sdr,
120116
scale_invariant_signal_distortion_ratio,
121117
"speaker-wise",
122118
"max",
123119
),
124-
(inputs2.preds, inputs2.target, snr_pit_scipy, signal_noise_ratio, "speaker-wise", "max"),
120+
(inputs2.preds, inputs2.target, _reference_scipy_pit_snr, signal_noise_ratio, "speaker-wise", "max"),
125121
(
126122
inputs2.preds,
127123
inputs2.target,
128-
si_sdr_pit_scipy,
124+
_reference_scipy_pit_si_sdr,
129125
scale_invariant_signal_distortion_ratio,
130126
"speaker-wise",
131127
"max",
132128
),
133-
(inputs1.preds, inputs1.target, snr_pit_scipy, signal_noise_ratio, "permutation-wise", "max"),
129+
(inputs1.preds, inputs1.target, _reference_scipy_pit_snr, signal_noise_ratio, "permutation-wise", "max"),
134130
(
135131
inputs1.preds,
136132
inputs1.target,
137-
si_sdr_pit_scipy,
133+
_reference_scipy_pit_si_sdr,
138134
scale_invariant_signal_distortion_ratio,
139135
"permutation-wise",
140136
"max",
141137
),
142-
(inputs2.preds, inputs2.target, snr_pit_scipy, signal_noise_ratio, "permutation-wise", "max"),
138+
(inputs2.preds, inputs2.target, _reference_scipy_pit_snr, signal_noise_ratio, "permutation-wise", "max"),
143139
(
144140
inputs2.preds,
145141
inputs2.target,
146-
si_sdr_pit_scipy,
142+
_reference_scipy_pit_si_sdr,
147143
scale_invariant_signal_distortion_ratio,
148144
"permutation-wise",
149145
"max",
@@ -163,7 +159,7 @@ def test_pit(self, preds, target, ref_metric, metric_func, mode, eval_func, ddp)
163159
preds,
164160
target,
165161
PermutationInvariantTraining,
166-
reference_metric=partial(_average_metric, metric_func=ref_metric),
162+
reference_metric=partial(_average_metric_wrapper, metric_func=ref_metric, res_index=0),
167163
metric_args={"metric_func": metric_func, "mode": mode, "eval_func": eval_func},
168164
)
169165

tests/unittests/audio/test_sa_sdr.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@
3838
)
3939

4040

41-
def _ref_metric(preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: bool):
41+
def _reference_local_sa_sdr(
42+
preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: bool, reduce_mean: bool = False
43+
):
4244
# According to the original paper, the sa-sdr equals to si-sdr with inputs concatenated over the speaker
4345
# dimension if scale_invariant==True. Accordingly, for scale_invariant==False, the sa-sdr equals to snr.
4446
# shape: preds [BATCH_SIZE, Spk, Time] , target [BATCH_SIZE, Spk, Time]
@@ -51,14 +53,14 @@ def _ref_metric(preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean:
5153
preds = preds.reshape(preds.shape[0], preds.shape[1] * preds.shape[2])
5254
target = target.reshape(target.shape[0], target.shape[1] * target.shape[2])
5355
if scale_invariant:
54-
return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=False)
55-
return signal_noise_ratio(preds=preds, target=target, zero_mean=zero_mean)
56-
57-
58-
def _average_metric(preds: Tensor, target: Tensor, scale_invariant: bool, zero_mean: bool):
59-
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
60-
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
61-
return _ref_metric(preds, target, scale_invariant, zero_mean).mean()
56+
sa_sdr = scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=False)
57+
else:
58+
sa_sdr = signal_noise_ratio(preds=preds, target=target, zero_mean=zero_mean)
59+
if reduce_mean:
60+
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
61+
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
62+
return sa_sdr.mean()
63+
return sa_sdr
6264

6365

6466
@pytest.mark.parametrize(
@@ -83,7 +85,9 @@ def test_si_sdr(self, preds, target, scale_invariant, zero_mean, ddp):
8385
preds,
8486
target,
8587
SourceAggregatedSignalDistortionRatio,
86-
reference_metric=partial(_average_metric, scale_invariant=scale_invariant, zero_mean=zero_mean),
88+
reference_metric=partial(
89+
_reference_local_sa_sdr, scale_invariant=scale_invariant, zero_mean=zero_mean, reduce_mean=True
90+
),
8791
metric_args={
8892
"scale_invariant": scale_invariant,
8993
"zero_mean": zero_mean,
@@ -96,7 +100,7 @@ def test_sa_sdr_functional(self, preds, target, scale_invariant, zero_mean):
96100
preds,
97101
target,
98102
source_aggregated_signal_distortion_ratio,
99-
reference_metric=partial(_ref_metric, scale_invariant=scale_invariant, zero_mean=zero_mean),
103+
reference_metric=partial(_reference_local_sa_sdr, scale_invariant=scale_invariant, zero_mean=zero_mean),
100104
metric_args={
101105
"scale_invariant": scale_invariant,
102106
"zero_mean": zero_mean,

tests/unittests/audio/test_sdr.py

Lines changed: 19 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from functools import partial
15-
from typing import Callable
1615

1716
import numpy as np
1817
import pytest
@@ -43,7 +42,9 @@
4342
)
4443

4544

46-
def _sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool = False) -> Tensor:
45+
def _reference_sdr_batch(
46+
preds: Tensor, target: Tensor, compute_permutation: bool = False, reduce_mean: bool = False
47+
) -> Tensor:
4748
# shape: preds [BATCH_SIZE, spk, Time] , target [BATCH_SIZE, spk, Time]
4849
# or shape: preds [NUM_BATCHES*BATCH_SIZE, spk, Time] , target [NUM_BATCHES*BATCH_SIZE, spk, Time]
4950
target = target.detach().cpu().numpy()
@@ -52,56 +53,49 @@ def _sdr_original_batch(preds: Tensor, target: Tensor, compute_permutation: bool
5253
for b in range(preds.shape[0]):
5354
sdr_val_np, _, _, _ = bss_eval_sources(target[b], preds[b], compute_permutation)
5455
mss.append(sdr_val_np)
55-
return torch.tensor(np.array(mss))
56+
sdr = torch.tensor(np.array(mss))
57+
if reduce_mean:
58+
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
59+
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
60+
return sdr.mean()
61+
return sdr
5662

5763

58-
def _average_metric(preds: Tensor, target: Tensor, metric_func: Callable) -> Tensor:
59-
# shape: preds [BATCH_SIZE, 1, Time] , target [BATCH_SIZE, 1, Time]
60-
# or shape: preds [NUM_BATCHES*BATCH_SIZE, 1, Time] , target [NUM_BATCHES*BATCH_SIZE, 1, Time]
61-
return metric_func(preds, target).mean()
62-
63-
64-
original_impl_compute_permutation = partial(_sdr_original_batch)
65-
66-
67-
@pytest.mark.skipif( # TODO: figure out why tests leads to cuda errors on latest torch
64+
@pytest.mark.skipif( # FIXME: figure out why tests leads to cuda errors on latest torch
6865
_TORCH_GREATER_EQUAL_1_11 and torch.cuda.is_available(), reason="tests leads to cuda errors on latest torch"
6966
)
7067
@pytest.mark.parametrize(
71-
"preds, target, ref_metric",
72-
[
73-
(inputs_1spk.preds, inputs_1spk.target, original_impl_compute_permutation),
74-
(inputs_2spk.preds, inputs_2spk.target, original_impl_compute_permutation),
75-
],
68+
"preds, target",
69+
[(inputs_1spk.preds, inputs_1spk.target), (inputs_2spk.preds, inputs_2spk.target)],
7670
)
7771
class TestSDR(MetricTester):
7872
"""Test class for `SignalDistortionRatio` metric."""
7973

8074
atol = 1e-2
8175

8276
@pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
83-
def test_sdr(self, preds, target, ref_metric, ddp):
77+
def test_sdr(self, preds, target, ddp):
8478
"""Test class implementation of metric."""
8579
self.run_class_metric_test(
8680
ddp,
8781
preds,
8882
target,
8983
SignalDistortionRatio,
90-
reference_metric=partial(_average_metric, metric_func=ref_metric),
84+
reference_metric=partial(_reference_sdr_batch, reduce_mean=True),
9185
metric_args={},
9286
)
9387

94-
def test_sdr_functional(self, preds, target, ref_metric):
88+
def test_sdr_functional(self, preds, target):
9589
"""Test functional implementation of metric."""
9690
self.run_functional_metric_test(
9791
preds,
9892
target,
9993
signal_distortion_ratio,
100-
ref_metric,
94+
_reference_sdr_batch,
10195
metric_args={},
10296
)
10397

104-
def test_sdr_differentiability(self, preds, target, ref_metric):
98+
def test_sdr_differentiability(self, preds, target):
10599
"""Test the differentiability of the metric, according to its `is_differentiable` attribute."""
106100
self.run_differentiability_test(
107101
preds=preds,
@@ -110,7 +104,7 @@ def test_sdr_differentiability(self, preds, target, ref_metric):
110104
metric_args={},
111105
)
112106

113-
def test_sdr_half_cpu(self, preds, target, ref_metric):
107+
def test_sdr_half_cpu(self, preds, target):
114108
"""Test dtype support of the metric on CPU."""
115109
self.run_precision_test_cpu(
116110
preds=preds,
@@ -121,7 +115,7 @@ def test_sdr_half_cpu(self, preds, target, ref_metric):
121115
)
122116

123117
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
124-
def test_sdr_half_gpu(self, preds, target, ref_metric):
118+
def test_sdr_half_gpu(self, preds, target):
125119
"""Test dtype support of the metric on GPU."""
126120
self.run_precision_test_gpu(
127121
preds=preds,

0 commit comments

Comments
 (0)