Fix computing features for CED audio tagging models. (#1341)
See also https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
This commit is contained in:
@@ -8,6 +8,7 @@
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <iomanip>
|
||||
#include <limits>
|
||||
#include <utility>
|
||||
|
||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||
@@ -110,7 +111,7 @@ class OfflineStream::Impl {
|
||||
config_.sampling_rate = opts_.frame_opts.samp_freq;
|
||||
}
|
||||
|
||||
explicit Impl(CEDTag /*tag*/) {
|
||||
explicit Impl(CEDTag /*tag*/) : is_ced_(true) {
|
||||
// see
|
||||
// https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
|
||||
|
||||
@@ -123,7 +124,9 @@ class OfflineStream::Impl {
|
||||
|
||||
opts_.frame_opts.samp_freq = 16000; // fixed to 16000
|
||||
opts_.mel_opts.num_bins = 64;
|
||||
opts_.mel_opts.low_freq = 0;
|
||||
opts_.mel_opts.high_freq = 8000;
|
||||
opts_.use_log_fbank = false;
|
||||
|
||||
config_.sampling_rate = opts_.frame_opts.samp_freq;
|
||||
|
||||
@@ -216,6 +219,10 @@ class OfflineStream::Impl {
|
||||
|
||||
NemoNormalizeFeatures(features.data(), n, feature_dim);
|
||||
|
||||
if (is_ced_) {
|
||||
AmplitudeToDB(features.data(), features.size());
|
||||
}
|
||||
|
||||
return features;
|
||||
}
|
||||
|
||||
@@ -226,6 +233,32 @@ class OfflineStream::Impl {
|
||||
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
||||
|
||||
private:
|
||||
// see
|
||||
// https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/functional.py#L359
|
||||
void AmplitudeToDB(float *p, int32_t n) const {
|
||||
float multiplier = 10;
|
||||
float top_db = 120;
|
||||
float amin = 1e-10;
|
||||
|
||||
float max_x = std::numeric_limits<float>::min();
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
float x = p[i];
|
||||
x = (x > amin) ? x : amin;
|
||||
x = std::log10f(x) * multiplier;
|
||||
|
||||
max_x = (x > max_x) ? x : max_x;
|
||||
p[i] = x;
|
||||
}
|
||||
|
||||
float d = max_x - top_db;
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
float x = p[i];
|
||||
x = (x > d) ? x : d;
|
||||
p[i] = x;
|
||||
}
|
||||
}
|
||||
|
||||
void NemoNormalizeFeatures(float *p, int32_t num_frames,
|
||||
int32_t feature_dim) const {
|
||||
if (config_.nemo_normalize_type.empty()) {
|
||||
@@ -266,6 +299,7 @@ class OfflineStream::Impl {
|
||||
knf::MfccOptions mfcc_opts_;
|
||||
OfflineRecognitionResult r_;
|
||||
ContextGraphPtr context_graph_;
|
||||
bool is_ced_ = false;
|
||||
};
|
||||
|
||||
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||
|
||||
Reference in New Issue
Block a user