diff --git a/runtime/core/decoder/params.h b/runtime/core/decoder/params.h
index 2b2344143..a3c1449a5 100644
--- a/runtime/core/decoder/params.h
+++ b/runtime/core/decoder/params.h
@@ -62,6 +62,7 @@ DEFINE_int32(core_number, 1, "Core number of process");
 // FeaturePipelineConfig flags
 DEFINE_int32(num_bins, 80, "num mel bins for fbank feature");
 DEFINE_int32(sample_rate, 16000, "sample rate for audio");
+DEFINE_string(feat_type, "kaldi", "Type of feature extraction: kaldi, whisper");
 
 // TLG fst
 DEFINE_string(fst_path, "", "TLG fst path");
@@ -115,9 +116,22 @@ DEFINE_int32(language_type, 0,
 DEFINE_bool(lowercase, true, "lowercase final result if needed");
 
 namespace wenet {
+
+// Maps the --feat_type flag value ("kaldi" or "whisper") to a FeatureType.
+// Throws std::invalid_argument for any other value.
+FeatureType StringToFeatureType(const std::string& feat_type_str) {
+  if (feat_type_str == "kaldi")
+    return FeatureType::kKaldi;
+  else if (feat_type_str == "whisper")
+    return FeatureType::kWhisper;
+  else
+    throw std::invalid_argument("Unsupported feat type!");
+}
+
 std::shared_ptr<FeaturePipelineConfig> InitFeaturePipelineConfigFromFlags() {
+  FeatureType feat_type = StringToFeatureType(FLAGS_feat_type);
   auto feature_config = std::make_shared<FeaturePipelineConfig>(
-      FLAGS_num_bins, FLAGS_sample_rate);
+      FLAGS_num_bins, FLAGS_sample_rate, feat_type);
   return feature_config;
 }
 
diff --git a/runtime/core/frontend/fbank.h b/runtime/core/frontend/fbank.h
index 5a650dc03..ad42307ca 100644
--- a/runtime/core/frontend/fbank.h
+++ b/runtime/core/frontend/fbank.h
@@ -28,9 +28,44 @@ namespace wenet {
 
 // This code is based on kaldi Fbank implementation, please see
 // https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc
+
+// 2^15; Compute() divides samples by this when scale_input_to_unit is set.
+static const int kS16AbsMax = 1 << 15;
+
+// Analysis window applied to each frame, see InitWindow().
+enum class WindowType {
+  kPovey = 0,
+  kHanning,
+};
+
+// Mel scale formula used to build the filter bank, see MelScale().
+enum class MelType {
+  kHTK = 0,
+  kSlaney,
+};
+
+// Final normalization in Compute(): kWhisper runs WhisperNorm(), kKaldi none.
+enum class NormalizationType {
+  kKaldi = 0,
+  kWhisper,
+};
+
+// Logarithm applied to mel energies when use_log_ is set.
+enum class LogBase {
+  kBaseE = 0,
+  kBase10,
+};
+
 class Fbank {
  public:
-  Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift)
+  Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift,
+        float low_freq = 20, bool pre_emphasis = true,
+        bool scale_input_to_unit = false,
+        float log_floor = std::numeric_limits<float>::epsilon(),
+        LogBase log_base = LogBase::kBaseE,
+        WindowType window_type = WindowType::kPovey,
+        MelType mel_type = MelType::kHTK,
+        NormalizationType norm_type = NormalizationType::kKaldi)
       : num_bins_(num_bins),
         sample_rate_(sample_rate),
         frame_length_(frame_length),
@@ -39,7 +74,14 @@ class Fbank {
         remove_dc_offset_(true),
         generator_(0),
         distribution_(0, 1.0),
-        dither_(0.0) {
+        dither_(0.0),
+        low_freq_(low_freq),
+        high_freq_(sample_rate / 2),
+        pre_emphasis_(pre_emphasis),
+        scale_input_to_unit_(scale_input_to_unit),
+        log_floor_(log_floor),
+        log_base_(log_base),
+        norm_type_(norm_type) {
     fft_points_ = UpperPowerOfTwo(frame_length_);
     // generate bit reversal table and trigonometric function table
     const int fft_points_4 = fft_points_ / 4;
@@ -47,32 +89,56 @@ class Fbank {
     sintbl_.resize(fft_points_ + fft_points_4);
     make_sintbl(fft_points_, sintbl_.data());
     make_bitrev(fft_points_, bitrev_.data());
+    InitMelFilters(mel_type);
+    InitWindow(window_type);
+  }
 
+  void InitMelFilters(MelType mel_type) {
     int num_fft_bins = fft_points_ / 2;
     float fft_bin_width = static_cast<float>(sample_rate_) / fft_points_;
-    int low_freq = 20, high_freq = sample_rate_ / 2;
-    float mel_low_freq = MelScale(low_freq);
-    float mel_high_freq = MelScale(high_freq);
-    float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
+    float mel_low_freq = MelScale(low_freq_, mel_type);
+    float mel_high_freq = MelScale(high_freq_, mel_type);
+    float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins_ + 1);
     bins_.resize(num_bins_);
     center_freqs_.resize(num_bins_);
-    for (int bin = 0; bin < num_bins; ++bin) {
+
+    for (int bin = 0; bin < num_bins_; ++bin) {
       float left_mel = mel_low_freq + bin * mel_freq_delta,
             center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
             right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
-      center_freqs_[bin] = InverseMelScale(center_mel);
+      center_freqs_[bin] = InverseMelScale(center_mel, mel_type);
       std::vector<float> this_bin(num_fft_bins);
       int first_index = -1, last_index = -1;
       for (int i = 0; i < num_fft_bins; ++i) {
         float freq = (fft_bin_width * i);  // Center frequency of this fft
                                            // bin.
-        float mel = MelScale(freq);
+        float mel = MelScale(freq, mel_type);
         if (mel > left_mel && mel < right_mel) {
-          float weight;
-          if (mel <= center_mel)
-            weight = (mel - left_mel) / (center_mel - left_mel);
-          else
-            weight = (right_mel - mel) / (right_mel - center_mel);
+          float weight = 0.0f;
+          if (mel_type == MelType::kHTK) {
+            if (mel <= center_mel)
+              weight = (mel - left_mel) / (center_mel - left_mel);
+            else
+              weight = (right_mel - mel) / (right_mel - center_mel);
+          } else if (mel_type == MelType::kSlaney) {
+            // Triangles are linear in Hz, scaled by 2 / (right_hz - left_hz)
+            // so that every filter has unit area.
+            if (mel <= center_mel) {
+              weight = (InverseMelScale(mel, mel_type) -
+                        InverseMelScale(left_mel, mel_type)) /
+                       (InverseMelScale(center_mel, mel_type) -
+                        InverseMelScale(left_mel, mel_type));
+              weight *= 2.0 / (InverseMelScale(right_mel, mel_type) -
+                               InverseMelScale(left_mel, mel_type));
+            } else {
+              weight = (InverseMelScale(right_mel, mel_type) -
+                        InverseMelScale(mel, mel_type)) /
+                       (InverseMelScale(right_mel, mel_type) -
+                        InverseMelScale(center_mel, mel_type));
+              weight *= 2.0 / (InverseMelScale(right_mel, mel_type) -
+                               InverseMelScale(left_mel, mel_type));
+            }
+          }
           this_bin[i] = weight;
           if (first_index == -1) first_index = i;
           last_index = i;
@@ -86,12 +152,20 @@ class Fbank {
         bins_[bin].second[i] = this_bin[first_index + i];
       }
     }
+  }
 
-    // povey window
-    povey_window_.resize(frame_length_);
-    double a = M_2PI / (frame_length - 1);
-    for (int i = 0; i < frame_length; ++i) {
-      povey_window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
+  void InitWindow(WindowType window_type) {
+    window_.resize(frame_length_);
+    if (window_type == WindowType::kPovey) {
+      // povey window
+      double a = M_2PI / (frame_length_ - 1);
+      for (int i = 0; i < frame_length_; ++i)
+        window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
+    } else if (window_type == WindowType::kHanning) {
+      // periodic hanning window
+      double a = M_2PI / (frame_length_);
+      for (int i = 0; i < frame_length_; ++i)
+        window_[i] = 0.5 * (1.0 - cos(i * a));
     }
   }
 
@@ -105,12 +179,45 @@ class Fbank {
 
   int num_bins() const { return num_bins_; }
 
-  static inline float InverseMelScale(float mel_freq) {
-    return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
+  static inline float InverseMelScale(float mel_freq,
+                                      MelType mel_type = MelType::kHTK) {
+    if (mel_type == MelType::kHTK) {
+      return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
+    } else if (mel_type == MelType::kSlaney) {
+      float f_min = 0.0;
+      float f_sp = 200.0f / 3.0f;
+      float min_log_hz = 1000.0;
+      float freq = f_min + f_sp * mel_freq;
+      float min_log_mel = (min_log_hz - f_min) / f_sp;
+      float logstep = logf(6.4) / 27.0f;
+      if (mel_freq >= min_log_mel) {
+        return min_log_hz * expf(logstep * (mel_freq - min_log_mel));
+      } else {
+        return freq;
+      }
+    } else {
+      throw std::invalid_argument("Unsupported mel type!");
+    }
   }
 
-  static inline float MelScale(float freq) {
-    return 1127.0f * logf(1.0f + freq / 700.0f);
+  static inline float MelScale(float freq, MelType mel_type = MelType::kHTK) {
+    if (mel_type == MelType::kHTK) {
+      return 1127.0f * logf(1.0f + freq / 700.0f);
+    } else if (mel_type == MelType::kSlaney) {
+      float f_min = 0.0;
+      float f_sp = 200.0f / 3.0f;
+      float min_log_hz = 1000.0;
+      float mel = (freq - f_min) / f_sp;
+      float min_log_mel = (min_log_hz - f_min) / f_sp;
+      float logstep = logf(6.4) / 27.0f;
+      if (freq >= min_log_hz) {
+        return min_log_mel + logf(freq / min_log_hz) / logstep;
+      } else {
+        return mel;
+      }
+    } else {
+      throw std::invalid_argument("Unsupported mel type!");
+    }
   }
 
   static int UpperPowerOfTwo(int n) {
@@ -125,11 +232,27 @@ class Fbank {
     (*data)[0] -= coeff * (*data)[0];
   }
 
-  // Apply povey window on data in place
-  void Povey(std::vector<float>* data) const {
-    CHECK_GE(data->size(), povey_window_.size());
-    for (size_t i = 0; i < povey_window_.size(); ++i) {
-      (*data)[i] *= povey_window_[i];
+  // Apply window on data in place
+  void ApplyWindow(std::vector<float>* data) const {
+    CHECK_GE(data->size(), window_.size());
+    for (size_t i = 0; i < window_.size(); ++i) {
+      (*data)[i] *= window_[i];
+    }
+  }
+
+  // Whisper-style normalization, applied in place: clamp every value to at
+  // least (max_mel_energy - 8), then map x -> (x + 4) / 4. max_mel_energy is
+  // the maximum over the whole feature matrix as tracked by Compute().
+  void WhisperNorm(std::vector<std::vector<float>>* feat,
+                   float max_mel_energy) {
+    int num_frames = feat->size();
+    for (int i = 0; i < num_frames; ++i) {
+      for (int j = 0; j < num_bins_; ++j) {
+        float energy = (*feat)[i][j];
+        if (energy < max_mel_energy - 8) energy = max_mel_energy - 8;
+        energy = (energy + 4.0) / 4.0;
+        (*feat)[i][j] = energy;
+      }
     }
   }
 
@@ -137,14 +260,27 @@ class Fbank {
   int Compute(const std::vector<float>& wave,
               std::vector<std::vector<float>>* feat) {
     int num_samples = wave.size();
+
     if (num_samples < frame_length_) return 0;
     int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_);
     feat->resize(num_frames);
     std::vector<float> fft_real(fft_points_, 0), fft_img(fft_points_, 0);
     std::vector<float> power(fft_points_ / 2);
+
+    // Start from the most negative float: log-mel energies can be < 0, and
+    // numeric_limits<float>::min() is the smallest *positive* value.
+    float max_mel_energy = std::numeric_limits<float>::lowest();
+
     for (int i = 0; i < num_frames; ++i) {
       std::vector<float> data(wave.data() + i * frame_shift_,
                               wave.data() + i * frame_shift_ + frame_length_);
+
+      if (scale_input_to_unit_) {
+        for (int j = 0; j < frame_length_; ++j) {
+          data[j] = data[j] / kS16AbsMax;
+        }
+      }
+
       // optional add noise
       if (dither_ != 0.0) {
         for (size_t j = 0; j < data.size(); ++j)
@@ -158,8 +294,10 @@ class Fbank {
         for (size_t j = 0; j < data.size(); ++j) data[j] -= mean;
       }
 
-      PreEmphasis(0.97, &data);
-      Povey(&data);
+      if (pre_emphasis_) {
+        PreEmphasis(0.97, &data);
+      }
+      ApplyWindow(&data);
       // copy data to fft_real
       memset(fft_img.data(), 0, sizeof(float) * fft_points_);
       memset(fft_real.data() + frame_length_, 0,
@@ -174,6 +312,7 @@ class Fbank {
 
       (*feat)[i].resize(num_bins_);
       // cepstral coefficients, triangle filter array
+
       for (int j = 0; j < num_bins_; ++j) {
         float mel_energy = 0.0;
         int s = bins_[j].first;
@@ -182,14 +321,20 @@ class Fbank {
         }
         // optional use log
         if (use_log_) {
-          if (mel_energy < std::numeric_limits<float>::epsilon())
-            mel_energy = std::numeric_limits<float>::epsilon();
-          mel_energy = logf(mel_energy);
-        }
+          if (mel_energy < log_floor_) mel_energy = log_floor_;
 
+          if (log_base_ == LogBase::kBaseE)
+            mel_energy = logf(mel_energy);
+          else if (log_base_ == LogBase::kBase10)
+            mel_energy = log10(mel_energy);
+        }
+        if (max_mel_energy < mel_energy) max_mel_energy = mel_energy;
         (*feat)[i][j] = mel_energy;
       }
     }
+    if (norm_type_ == NormalizationType::kWhisper)
+      WhisperNorm(feat, max_mel_energy);
+
     return num_frames;
   }
 
@@ -200,9 +345,17 @@ class Fbank {
   int fft_points_;
   bool use_log_;
   bool remove_dc_offset_;
+  bool pre_emphasis_;
+  bool scale_input_to_unit_;
+  float low_freq_;
+  float log_floor_;
+  float high_freq_;
+  LogBase log_base_;
+  NormalizationType norm_type_;
+
   std::vector<float> center_freqs_;
   std::vector<std::pair<int, std::vector<float>>> bins_;
-  std::vector<float> povey_window_;
+  std::vector<float> window_;
   std::default_random_engine generator_;
   std::normal_distribution<float> distribution_;
   float dither_;
diff --git a/runtime/core/frontend/feature_pipeline.cc b/runtime/core/frontend/feature_pipeline.cc
index ab450b15c..91be1d29c 100644
--- a/runtime/core/frontend/feature_pipeline.cc
+++ b/runtime/core/frontend/feature_pipeline.cc
@@ -23,7 +23,9 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineConfig& config)
     : config_(config),
       feature_dim_(config.num_bins),
       fbank_(config.num_bins, config.sample_rate, config.frame_length,
-             config.frame_shift),
+             config.frame_shift, config.low_freq, config.pre_emphasis,
+             config.scale_input_to_unit, config.log_floor, config.log_base,
+             config.window_type, config.mel_type, config.norm_type),
       num_frames_(0),
       input_finished_(false) {}
 
diff --git a/runtime/core/frontend/feature_pipeline.h b/runtime/core/frontend/feature_pipeline.h
index 9918d6b57..79d598022 100644
--- a/runtime/core/frontend/feature_pipeline.h
+++ b/runtime/core/frontend/feature_pipeline.h
@@ -15,6 +15,7 @@
 #ifndef FRONTEND_FEATURE_PIPELINE_H_
 #define FRONTEND_FEATURE_PIPELINE_H_
 
+#include <limits>
 #include <mutex>
 #include <queue>
 #include <string>
@@ -26,22 +27,61 @@
 
 namespace wenet {
 
+// Selects the preset of fbank options filled in by FeaturePipelineConfig.
+enum class FeatureType {
+  kKaldi = 0,
+  kWhisper,
+};
+
 struct FeaturePipelineConfig {
   int num_bins;
   int sample_rate;
   int frame_length;
   int frame_shift;
-  FeaturePipelineConfig(int num_bins, int sample_rate)
+  float low_freq;
+  bool pre_emphasis;
+  bool scale_input_to_unit;
+  float log_floor;
+  LogBase log_base;
+  WindowType window_type;
+  MelType mel_type;
+  NormalizationType norm_type;
+
+  FeaturePipelineConfig(int num_bins, int sample_rate,
+                        FeatureType feat_type = FeatureType::kKaldi)
       : num_bins(num_bins),                  // 80 dim fbank
         sample_rate(sample_rate) {           // 16k sample rate
     frame_length = sample_rate / 1000 * 25;  // frame length 25ms
     frame_shift = sample_rate / 1000 * 10;   // frame shift 10ms
+    if (feat_type == FeatureType::kKaldi) {
+      low_freq = 20.0;
+      pre_emphasis = true;
+      log_floor = std::numeric_limits<float>::epsilon();
+      log_base = LogBase::kBaseE;
+      window_type = WindowType::kPovey;
+      mel_type = MelType::kHTK;
+      norm_type = NormalizationType::kKaldi;
+      scale_input_to_unit = false;
+    } else if (feat_type == FeatureType::kWhisper) {
+      low_freq = 0.0;
+      pre_emphasis = false;
+      log_floor = 1e-10;
+      log_base = LogBase::kBase10;
+      window_type = WindowType::kHanning;
+      mel_type = MelType::kSlaney;
+      scale_input_to_unit = true;
+      norm_type = NormalizationType::kWhisper;
+    }
   }
 
   void Info() const {
     LOG(INFO) << "feature pipeline config"
               << " num_bins " << num_bins << " frame_length " << frame_length
-              << " frame_shift " << frame_shift;
+              << " frame_shift " << frame_shift << " low_freq " << low_freq
+              << " preemphasis " << pre_emphasis << " log_floor " << log_floor
+              << " log_base " << int(log_base) << " window_type "
+              << int(window_type) << " mel_type " << int(mel_type)
+              << " norm_type " << int(norm_type);
   }
 };
 