Skip to content

Commit b56bc55

Browse files
committed
Final optimizations to time-frequency (tf) sonification
1 parent f938c23 commit b56bc55

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

mir_eval/sonify.py

+17-15
Original file line numberDiff line numberDiff line change
@@ -133,15 +133,23 @@ def time_frequency(
133133
f"frequencies.shape={frequencies.shape} is incompatible with gram.shape={gram.shape}"
134134
)
135135

136+
padding = [0, 0]
137+
stacking = []
138+
136139
if times.min() > 0:
137140
# We need to pad a silence column onto gram at the beginning
138-
gram = np.pad(gram, ((0, 0), (1, 0)), mode="constant")
139-
times = np.vstack(([0, times.min()], times))
141+
padding[0] = 1
142+
stacking.append([0, times.min()])
143+
144+
stacking.append(times)
140145

141146
if times.max() < last_time_in_secs:
142147
# We need to pad a silence column onto gram at the end
143-
gram = np.pad(gram, ((0, 0), (0, 1)), mode="constant")
144-
times = np.vstack((times, [times.max(), last_time_in_secs]))
148+
padding[1] = 1
149+
stacking.append([times.max(), last_time_in_secs])
150+
151+
gram = np.pad(gram, ((0, 0), padding), mode="constant")
152+
times = np.vstack(stacking)
145153

146154
# Identify the time intervals that have some overlap with the duration
147155
idx = np.logical_and(times[:, 1] >= 0, times[:, 0] <= last_time_in_secs)
@@ -150,10 +158,6 @@ def time_frequency(
150158

151159
n_times = times.shape[0]
152160

153-
# Round up to ensure that the adjusted interval last time does not diverge from length
154-
# due to a loss of precision and truncation to ints.
155-
sample_intervals = np.round(times * fs).astype(int)
156-
157161
# Threshold the tfgram to remove negative values
158162
gram = np.maximum(gram, 0)
159163

@@ -164,7 +168,11 @@ def time_frequency(
164168
# the empty signal.
165169
return output
166170

167-
time_centers = np.mean(times, axis=1) * float(fs)
171+
# Discard frequencies whose maximum magnitude in gram is below the threshold
172+
freq_keep = np.max(gram, axis=1) >= threshold
173+
174+
gram = gram[freq_keep, :]
175+
frequencies = frequencies[freq_keep]
168176

169177
# Interpolate the values in gram over the time grid.
170178
if n_times > 1:
@@ -185,13 +193,7 @@ def time_frequency(
185193

186194
signal = interpolator(np.arange(length))
187195

188-
# Check if there is at least one element on each frequency that has a value above the threshold
189-
# to justify processing, for optimisation.
190-
spectral_max_magnitudes = np.max(gram, axis=1)
191196
for n, frequency in enumerate(frequencies):
192-
if spectral_max_magnitudes[n] < threshold:
193-
continue
194-
195197
# Get a waveform of length samples at this frequency
196198
wave = _fast_synthesize(frequency, n_dec, fs, function, length)
197199

0 commit comments

Comments
 (0)