Skip to content

Commit 9e87b0d

Browse files
committed
engine: cleanup/improve/fix beat tracking code (#1881)
1 parent 77615ce commit 9e87b0d

File tree

2 files changed

+102
-66
lines changed

2 files changed

+102
-66
lines changed

engine/audio/src/beattracking.cpp

Lines changed: 85 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,24 @@ const double filterCoeffB[] = { 0.15998789, 0.31997577, 0.15998789 };
3333

3434
BeatTracking::BeatTracking(int channels, QObject *parent)
3535
: QObject(parent)
36+
, m_channels(channels)
37+
, m_sampleRate(BEAT_DEFAULT_SAMPLE_RATE)
38+
, m_windowSize(BEAT_DEFAULT_WINDOW_SIZE)
39+
, m_hopSize(BEAT_DEFAULT_HOP_SIZE)
40+
, m_onsetWindowSize(ONSET_WINDOW_SIZE)
41+
, m_lastLag(0.0)
42+
, m_consistencyCount(0)
43+
, m_blockPosition(-1)
44+
, m_identifiedLag(0.0)
45+
, m_currentBPM(120)
46+
, m_currentMs(500)
47+
, m_silenceGateThreshold(0.001) // ≈ -60 dBFS on normalized [-1,1] audio
3648
{
37-
m_channels = channels;
38-
m_sampleRate = BEAT_DEFAULT_SAMPLE_RATE;
39-
m_windowSize = BEAT_DEFAULT_WINDOW_SIZE;
40-
m_hopSize = BEAT_DEFAULT_HOP_SIZE;
41-
m_onsetWindowSize = ONSET_WINDOW_SIZE;
42-
43-
m_currentBPM = 120;
44-
m_currentMs = 500;
45-
4649
m_windowWeights = calculateWindowWeights(m_windowSize);
47-
m_onsetWeights = calculateWindowWeights(15);
4850

49-
m_identifiedLag = 0;
50-
m_lastLag = 0;
51-
m_consistencyCount = 0;
51+
m_prevMagnitudes.resize(m_windowSize / 2);
52+
m_prevMagnitudes.fill(0.0);
53+
5254
double targetLag = (m_sampleRate * 60.0) / (RAILEIGH_TARGET_BPM*m_hopSize);
5355
m_continuityDerivation = targetLag/8;
5456

@@ -83,9 +85,9 @@ QVector<double> BeatTracking::calculateWindowWeights(int windowSize)
8385
{
8486
QVector<double> returnVector(windowSize);
8587

86-
// pre calculate hanningz window
88+
// pre calculate Hann window
8789
for (int i = 0; i < windowSize; i++)
88-
returnVector[i] = 0.5 * (1.0 - qCos(M_PI * M_PI * i / (windowSize)));
90+
returnVector[i] = 0.5 * (1.0 - cos(2 * M_PI * i / (windowSize - 1)));
8991

9092
return returnVector;
9193
}
@@ -121,37 +123,60 @@ bool BeatTracking::processAudio(int16_t * buffer, int bufferSize)
121123

122124
for (int i = 0; i < bufferSize; ++i)
123125
{
124-
// mixdown channels - Maybe do this before calling the function?
125-
m_windowBuffer += 0;
126-
for (unsigned int j = 0; j < m_channels; j++)
127-
m_windowBuffer[i] += static_cast<double>(buffer[i * m_channels + j]) / m_channels;
126+
m_windowBuffer.append(0.0);
127+
int idx = m_windowBuffer.size() - 1;
128128

129-
m_windowBuffer[i] += m_windowBuffer[i] / 32768.;
129+
for (unsigned int j = 0; j < m_channels; ++j)
130+
m_windowBuffer[idx] += static_cast<double>(buffer[i * m_channels + j]) / m_channels;
131+
132+
m_windowBuffer[idx] /= 32768.0;
130133
}
131134

132135
// 1024 windows size, 512 advance between frames
133136
while (m_windowBuffer.size() > m_windowSize)
134137
{
135-
memcpy(m_fftInputBuffer, m_windowBuffer.data(), m_windowSize);
138+
memcpy(m_fftInputBuffer, m_windowBuffer.constData(), m_windowSize * sizeof(double));
136139

140+
// Apply window
137141
for (int i = 0; i< m_windowSize; i++)
138142
m_fftInputBuffer[i] = m_fftInputBuffer[i] * m_windowWeights[i];
139143

140-
#ifdef HAS_FFTW3
141-
fftw_execute(m_planForward);
142-
#endif
144+
// compute RMS for silence gate on the windowed block ---
145+
double rms = 0.0;
146+
for (int i = 0; i < m_windowSize; ++i)
147+
{
148+
double s = m_fftInputBuffer[i];
149+
rms += s * s;
150+
}
151+
rms = std::sqrt(rms / m_windowSize);
143152

144-
QVector<double> magnitudes(m_windowSize/2, 0);
145153
double onsetValue = 0.0;
146-
for (int i = 0; i < m_windowSize/2; i++)
154+
155+
// If gate is active and level is below threshold, treat as silence
156+
if (m_silenceGateThreshold > 0.0 && rms < m_silenceGateThreshold)
147157
{
148-
double mag = qSqrt((reinterpret_cast<fftw_complex*>(m_fftOutputBuffer)[i][0] * reinterpret_cast<fftw_complex*>(m_fftOutputBuffer)[i][0]) +
149-
(reinterpret_cast<fftw_complex*>(m_fftOutputBuffer)[i][1] * reinterpret_cast<fftw_complex*>(m_fftOutputBuffer)[i][1]));
158+
// onsetValue stays 0.0
159+
// we intentionally skip FFT and magnitude update
160+
}
161+
else
162+
{
163+
#ifdef HAS_FFTW3
164+
fftw_execute(m_planForward);
150165

151-
if (mag > magnitudes[i])
152-
onsetValue += (mag - magnitudes[i]);
166+
auto *fftOut = reinterpret_cast<fftw_complex*>(m_fftOutputBuffer);
153167

154-
magnitudes[i] = mag;
168+
for (int i = 0; i < m_windowSize / 2; ++i)
169+
{
170+
double re = fftOut[i][0];
171+
double im = fftOut[i][1];
172+
double mag = std::sqrt(re*re + im*im);
173+
174+
if (mag > m_prevMagnitudes[i])
175+
onsetValue += (mag - m_prevMagnitudes[i]);
176+
177+
m_prevMagnitudes[i] = mag;
178+
}
179+
#endif
155180
}
156181

157182
if (m_tOnsetValues.size() == m_onsetWindowSize)
@@ -232,16 +257,17 @@ bool BeatTracking::processAudio(int16_t * buffer, int bufferSize)
232257
}
233258
}
234259

235-
m_currentBPM = (44100 * 60) / (m_hopSize * m_identifiedLag);
236-
m_currentMs = m_identifiedLag * m_hopSize / 44.1;
260+
m_currentBPM = (m_sampleRate * 60.0) / (m_hopSize * m_identifiedLag);
261+
m_currentMs = (m_identifiedLag * m_hopSize * 1000.0) / m_sampleRate;
237262

238263
// Beat Tracking and phase detection
239264
int cMax = qFloor(m_tOnsetValues.size() / m_identifiedLag);
240265
QVector<double> phaseOnsetValues(m_tOnsetValues.size(), 0.0);
241266

242267
// reverse onset values and add weighting so that most recent events are more likely
243-
for (int i = 0; i < phaseOnsetValues.size(); i++)
244-
phaseOnsetValues[i] = m_tOnsetValues[m_tOnsetValues.size() - i - 1] * qExp(-1.0*i*(qLn(2)/m_identifiedLag));
268+
double decay = qLn(2) / m_identifiedLag;
269+
for (int i = 0; i < phaseOnsetValues.size(); ++i)
270+
phaseOnsetValues[i] = m_tOnsetValues[m_tOnsetValues.size() - i - 1] * qExp(-i * decay);
245271

246272
// phase calculation by autocorrelation with train of impulses
247273
QVector<double> phaseValues(phaseOnsetValues.size(), 0.0);
@@ -321,17 +347,15 @@ bool BeatTracking::processAudio(int16_t * buffer, int bufferSize)
321347
return isBeat;
322348
}
323349

324-
QVector<double> BeatTracking::getOnsetCorrelation(QList<double> onsetValues)
350+
QVector<double> BeatTracking::getOnsetCorrelation(const QList<double> &onsetValues)
325351
{
326352
QVector<double> autoCorr(onsetValues.size());
327353
for (int l = 0; l < onsetValues.size(); l++)
328354
{
329355
double divider = qAbs(l - onsetValues.size());
330356
double sum = 0.0;
331357
for (int i = l; i < onsetValues.size(); i++)
332-
{
333358
sum += onsetValues[i] * onsetValues[i - l];
334-
}
335359

336360
autoCorr[l] = sum / divider;
337361
}
@@ -357,12 +381,12 @@ QVector<double> BeatTracking::getOnsetCorrelation(QList<double> onsetValues)
357381
return combRes;
358382
}
359383

360-
int BeatTracking::getPredictedAcfLag(QVector<double> roCorr)
384+
int BeatTracking::getPredictedAcfLag(const QVector<double> &roCorr)
361385
{
362386
QVector<double> tps2(roCorr.size()/2);
363387
QVector<double> tps3(roCorr.size()/2);
364388

365-
double max2I = 0.0, max3I = 0.0;
389+
int max2I = 0.0, max3I = 0.0;
366390
double max2 = 0.0, max3 = 0.0;
367391
for (int r = 1; r < roCorr.size() / 2 - 1; r++)
368392
{
@@ -390,7 +414,7 @@ int BeatTracking::getPredictedAcfLag(QVector<double> roCorr)
390414
return max3I;
391415
}
392416

393-
QVector<double> BeatTracking::calculateBiquadFilter(QList<double> values)
417+
QVector<double> BeatTracking::calculateBiquadFilter(const QList<double> &values)
394418
{
395419
QVector<double> processed;
396420
processed.fill(0.0, values.size());
@@ -425,7 +449,7 @@ QVector<double> BeatTracking::calculateBiquadFilter(QList<double> values)
425449
return processed;
426450
}
427451

428-
double BeatTracking::getMean(QVector<double> values)
452+
double BeatTracking::getMean(const QVector<double> &values)
429453
{
430454
double mean = 0.0;
431455
for (double value : values)
@@ -438,16 +462,28 @@ double BeatTracking::getMean(QVector<double> values)
438462
double BeatTracking::getMedian(QVector<double> values)
439463
{
440464
std::sort(values.begin(), values.end());
465+
int n = values.size();
466+
if (n == 0)
467+
return 0.0; // or throw/assert
441468

442-
return values[qFloor(values.size() / 2) + 1];
469+
if (n % 2 == 1)
470+
return values[n / 2];
471+
else
472+
return 0.5 * (values[n/2 - 1] + values[n/2]);
443473
}
444474

445-
double BeatTracking::getQuadraticValue(int position, QVector<double> vector)
475+
double BeatTracking::getQuadraticValue(int position, const QVector<double> &v)
446476
{
447-
double prevValue = 0;
477+
if (position <= 0 || position >= v.size() - 1)
478+
return static_cast<double>(position);
448479

449-
if (position > 0)
450-
prevValue = vector[position - 1];
480+
double y0 = v[position - 1];
481+
double y1 = v[position];
482+
double y2 = v[position + 1];
451483

452-
return static_cast<double>(position) + 0.5 * (prevValue - vector[position + 1]) / (prevValue - 2 * vector[position] + vector[position + 1]);
453-
}
484+
double denom = (y0 - 2 * y1 + y2);
485+
if (denom == 0.0)
486+
return static_cast<double>(position);
487+
488+
return position + 0.5 * (y0 - y2) / denom;
489+
}

engine/audio/src/beattracking.h

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
#include <QObject>
2424
#ifdef HAS_FFTW3
25-
#include "fftw3.h"
25+
#include "fftw3.h"
2626
#endif
2727

2828
/** @addtogroup engine Engine
@@ -40,9 +40,9 @@ class BeatTracking : public QObject
4040
Q_OBJECT
4141

4242
public:
43-
BeatTracking(int channels, QObject * parent = nullptr);
43+
BeatTracking(int channels, QObject *parent = nullptr);
4444
~BeatTracking();
45-
bool processAudio(int16_t * buffer, int bufferSize);
45+
bool processAudio(int16_t *buffer, int bufferSize);
4646

4747
private:
4848
enum PredictionState{ACF, CONTINUITY};
@@ -57,42 +57,40 @@ class BeatTracking : public QObject
5757
int m_hopSize;
5858
int m_onsetWindowSize;
5959

60-
float * currentFrame;
61-
6260
// FFT information
63-
double * m_fftInputBuffer;
64-
void * m_fftOutputBuffer;
61+
double *m_fftInputBuffer;
62+
void *m_fftOutputBuffer;
63+
#ifdef HAS_FFTW3
6564
fftw_plan m_planForward;
66-
65+
#endif
6766

6867
// stored values that are currently processed
6968
QVector<double> m_windowBuffer;
7069

7170
// weighting storage
7271
QVector<double> m_windowWeights;
73-
QVector<double> m_onsetWeights;
7472

7573
// methods
7674
QVector<double> calculateWindowWeights(int windowSize);
77-
QVector<double> calculateBiquadFilter(QList<double> values);
78-
QVector<double> getGaussianWeighting(int length, double tLag);
7975
QVector<double> getRaileighFilterBank(int length, double tLag);
80-
QVector<double> getOnsetCorrelation(QList<double> onsetValues);
81-
int getPredictedAcfLag(QVector<double> oCorr);
82-
double getMean(QVector<double> values);
83-
double getMedian(QVector<double> values);
84-
double getQuadraticValue(int position, QVector<double> vector);
76+
QVector<double> getGaussianWeighting(int length, double tLag);
77+
int getPredictedAcfLag(const QVector<double> &oCorr);
78+
QVector<double> getOnsetCorrelation(const QList<double> &onsetValues);
79+
QVector<double> calculateBiquadFilter(const QList<double> &values);
80+
double getMean(const QVector<double> &values);
81+
double getMedian(QVector<double> values); // keep by value since we sort
82+
double getQuadraticValue(int position, const QVector<double> &vector);
8583

8684

8785
QVector<double> m_raileighFilterBank;
8886
QVector<double> m_gaussianFilterBank;
87+
QVector<double> m_prevMagnitudes;
8988

9089
// onset storage
9190
QList<double> m_tOnsetValues;
9291
QList<double> m_onsetValuesProcessed;
9392

9493
// consistency - Context dependent model
95-
double m_lastDifference;
9694
double m_lastLag;
9795
int m_consistencyCount;
9896
double m_continuityDerivation;
@@ -106,6 +104,8 @@ class BeatTracking : public QObject
106104
double m_identifiedLag;
107105
double m_currentBPM;
108106
double m_currentMs;
107+
108+
double m_silenceGateThreshold; // RMS threshold; 0.0 disables gate
109109
};
110110

111111
/** @} */

0 commit comments

Comments
 (0)