3131)
3232
3333from unittests import BATCH_SIZE , NUM_BATCHES , _Input
34+ from unittests .audio import _average_metric_wrapper
3435from unittests .helpers import seed_all
3536from unittests .helpers .testers import MetricTester
3637
3738seed_all (42 )
3839
39- TIME = 10
40+ TIME_FRAME = 10
4041
4142
4243# three speaker examples to test _find_best_perm_by_linear_sum_assignment
4344inputs1 = _Input (
44- preds = torch .rand (NUM_BATCHES , BATCH_SIZE , 3 , TIME ),
45- target = torch .rand (NUM_BATCHES , BATCH_SIZE , 3 , TIME ),
45+ preds = torch .rand (NUM_BATCHES , BATCH_SIZE , 3 , TIME_FRAME ),
46+ target = torch .rand (NUM_BATCHES , BATCH_SIZE , 3 , TIME_FRAME ),
4647)
4748# two-speaker examples to test _find_best_perm_by_exhaustive_method
4849inputs2 = _Input (
49- preds = torch .rand (NUM_BATCHES , BATCH_SIZE , 2 , TIME ),
50- target = torch .rand (NUM_BATCHES , BATCH_SIZE , 2 , TIME ),
50+ preds = torch .rand (NUM_BATCHES , BATCH_SIZE , 2 , TIME_FRAME ),
51+ target = torch .rand (NUM_BATCHES , BATCH_SIZE , 2 , TIME_FRAME ),
5152)
5253
5354
54- def naive_implementation_pit_scipy (
55+ def _reference_scipy_pit (
5556 preds : Tensor ,
5657 target : Tensor ,
5758 metric_func : Callable ,
@@ -66,10 +67,8 @@ def naive_implementation_pit_scipy(
6667 eval_func: min or max
6768
6869 Returns:
69- best_metric:
70- shape [batch]
71- best_perm:
72- shape [batch, spk]
70+ best_metric: shape [batch]
71+ best_perm: shape [batch, spk]
7372
7473 """
7574 batch_size , spk_num = target .shape [0 :2 ]
@@ -88,62 +87,59 @@ def naive_implementation_pit_scipy(
8887 return torch .from_numpy (np .stack (best_metrics )), torch .from_numpy (np .stack (best_perms ))
8988
9089
91- def _average_metric (preds : Tensor , target : Tensor , metric_func : Callable ) -> Tensor :
92- """Average the metric values.
def _reference_scipy_pit_snr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
    """Run the scipy-based reference PIT with ``signal_noise_ratio`` as the metric.

    Thin wrapper that forwards to ``_reference_scipy_pit`` with the metric
    function fixed to ``signal_noise_ratio`` and the evaluation function fixed
    to ``"max"``.

    Args:
        preds: predictions, passed through unchanged
        target: targets, passed through unchanged

    Returns:
        The result of ``_reference_scipy_pit``, documented there as
        best_metric (shape [batch]) and best_perm (shape [batch, spk]).

    """
    fixed_kwargs = {"metric_func": signal_noise_ratio, "eval_func": "max"}
    return _reference_scipy_pit(preds=preds, target=target, **fixed_kwargs)
9397
94- Args:
95- preds: predictions, shape[batch, spk, time]
96- target: targets, shape[batch, spk, time]
97- metric_func: a function which return best_metric and best_perm
98-
99- Returns:
100- the average of best_metric
10198
102- """
103- return metric_func (preds , target )[0 ].mean ()
104-
105-
106- snr_pit_scipy = partial (naive_implementation_pit_scipy , metric_func = signal_noise_ratio , eval_func = "max" )
107- si_sdr_pit_scipy = partial (
108- naive_implementation_pit_scipy , metric_func = scale_invariant_signal_distortion_ratio , eval_func = "max"
109- )
def _reference_scipy_pit_si_sdr(preds: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
    """Run the scipy-based reference PIT with ``scale_invariant_signal_distortion_ratio`` as the metric.

    Thin wrapper that forwards to ``_reference_scipy_pit`` with the metric
    function fixed to ``scale_invariant_signal_distortion_ratio`` and the
    evaluation function fixed to ``"max"``.

    Args:
        preds: predictions, passed through unchanged
        target: targets, passed through unchanged

    Returns:
        The result of ``_reference_scipy_pit``, documented there as
        best_metric (shape [batch]) and best_perm (shape [batch, spk]).

    """
    fixed_kwargs = {"metric_func": scale_invariant_signal_distortion_ratio, "eval_func": "max"}
    return _reference_scipy_pit(preds=preds, target=target, **fixed_kwargs)
110106
111107
112108@pytest .mark .parametrize (
113109 "preds, target, ref_metric, metric_func, mode, eval_func" ,
114110 [
115- (inputs1 .preds , inputs1 .target , snr_pit_scipy , signal_noise_ratio , "speaker-wise" , "max" ),
111+ (inputs1 .preds , inputs1 .target , _reference_scipy_pit_snr , signal_noise_ratio , "speaker-wise" , "max" ),
116112 (
117113 inputs1 .preds ,
118114 inputs1 .target ,
119- si_sdr_pit_scipy ,
115+ _reference_scipy_pit_si_sdr ,
120116 scale_invariant_signal_distortion_ratio ,
121117 "speaker-wise" ,
122118 "max" ,
123119 ),
124- (inputs2 .preds , inputs2 .target , snr_pit_scipy , signal_noise_ratio , "speaker-wise" , "max" ),
120+ (inputs2 .preds , inputs2 .target , _reference_scipy_pit_snr , signal_noise_ratio , "speaker-wise" , "max" ),
125121 (
126122 inputs2 .preds ,
127123 inputs2 .target ,
128- si_sdr_pit_scipy ,
124+ _reference_scipy_pit_si_sdr ,
129125 scale_invariant_signal_distortion_ratio ,
130126 "speaker-wise" ,
131127 "max" ,
132128 ),
133- (inputs1 .preds , inputs1 .target , snr_pit_scipy , signal_noise_ratio , "permutation-wise" , "max" ),
129+ (inputs1 .preds , inputs1 .target , _reference_scipy_pit_snr , signal_noise_ratio , "permutation-wise" , "max" ),
134130 (
135131 inputs1 .preds ,
136132 inputs1 .target ,
137- si_sdr_pit_scipy ,
133+ _reference_scipy_pit_si_sdr ,
138134 scale_invariant_signal_distortion_ratio ,
139135 "permutation-wise" ,
140136 "max" ,
141137 ),
142- (inputs2 .preds , inputs2 .target , snr_pit_scipy , signal_noise_ratio , "permutation-wise" , "max" ),
138+ (inputs2 .preds , inputs2 .target , _reference_scipy_pit_snr , signal_noise_ratio , "permutation-wise" , "max" ),
143139 (
144140 inputs2 .preds ,
145141 inputs2 .target ,
146- si_sdr_pit_scipy ,
142+ _reference_scipy_pit_si_sdr ,
147143 scale_invariant_signal_distortion_ratio ,
148144 "permutation-wise" ,
149145 "max" ,
@@ -163,7 +159,7 @@ def test_pit(self, preds, target, ref_metric, metric_func, mode, eval_func, ddp)
163159 preds ,
164160 target ,
165161 PermutationInvariantTraining ,
166- reference_metric = partial (_average_metric , metric_func = ref_metric ),
162+ reference_metric = partial (_average_metric_wrapper , metric_func = ref_metric , res_index = 0 ),
167163 metric_args = {"metric_func" : metric_func , "mode" : mode , "eval_func" : eval_func },
168164 )
169165
0 commit comments