Fix computing features for CED audio tagging models. (#1341)
See also https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
This commit is contained in:
@@ -8,6 +8,7 @@
|
|||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <iomanip>
|
#include <iomanip>
|
||||||
|
#include <limits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||||
@@ -110,7 +111,7 @@ class OfflineStream::Impl {
|
|||||||
config_.sampling_rate = opts_.frame_opts.samp_freq;
|
config_.sampling_rate = opts_.frame_opts.samp_freq;
|
||||||
}
|
}
|
||||||
|
|
||||||
explicit Impl(CEDTag /*tag*/) {
|
explicit Impl(CEDTag /*tag*/) : is_ced_(true) {
|
||||||
// see
|
// see
|
||||||
// https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
|
// https://github.com/RicherMans/CED/blob/main/onnx_inference_with_kaldi.py
|
||||||
|
|
||||||
@@ -123,7 +124,9 @@ class OfflineStream::Impl {
|
|||||||
|
|
||||||
opts_.frame_opts.samp_freq = 16000; // fixed to 16000
|
opts_.frame_opts.samp_freq = 16000; // fixed to 16000
|
||||||
opts_.mel_opts.num_bins = 64;
|
opts_.mel_opts.num_bins = 64;
|
||||||
|
opts_.mel_opts.low_freq = 0;
|
||||||
opts_.mel_opts.high_freq = 8000;
|
opts_.mel_opts.high_freq = 8000;
|
||||||
|
opts_.use_log_fbank = false;
|
||||||
|
|
||||||
config_.sampling_rate = opts_.frame_opts.samp_freq;
|
config_.sampling_rate = opts_.frame_opts.samp_freq;
|
||||||
|
|
||||||
@@ -216,6 +219,10 @@ class OfflineStream::Impl {
|
|||||||
|
|
||||||
NemoNormalizeFeatures(features.data(), n, feature_dim);
|
NemoNormalizeFeatures(features.data(), n, feature_dim);
|
||||||
|
|
||||||
|
if (is_ced_) {
|
||||||
|
AmplitudeToDB(features.data(), features.size());
|
||||||
|
}
|
||||||
|
|
||||||
return features;
|
return features;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -226,6 +233,32 @@ class OfflineStream::Impl {
|
|||||||
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
const ContextGraphPtr &GetContextGraph() const { return context_graph_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// see
|
||||||
|
// https://github.com/pytorch/audio/blob/main/src/torchaudio/functional/functional.py#L359
|
||||||
|
void AmplitudeToDB(float *p, int32_t n) const {
|
||||||
|
float multiplier = 10;
|
||||||
|
float top_db = 120;
|
||||||
|
float amin = 1e-10;
|
||||||
|
|
||||||
|
float max_x = std::numeric_limits<float>::min();
|
||||||
|
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
float x = p[i];
|
||||||
|
x = (x > amin) ? x : amin;
|
||||||
|
x = std::log10f(x) * multiplier;
|
||||||
|
|
||||||
|
max_x = (x > max_x) ? x : max_x;
|
||||||
|
p[i] = x;
|
||||||
|
}
|
||||||
|
|
||||||
|
float d = max_x - top_db;
|
||||||
|
for (int32_t i = 0; i != n; ++i) {
|
||||||
|
float x = p[i];
|
||||||
|
x = (x > d) ? x : d;
|
||||||
|
p[i] = x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void NemoNormalizeFeatures(float *p, int32_t num_frames,
|
void NemoNormalizeFeatures(float *p, int32_t num_frames,
|
||||||
int32_t feature_dim) const {
|
int32_t feature_dim) const {
|
||||||
if (config_.nemo_normalize_type.empty()) {
|
if (config_.nemo_normalize_type.empty()) {
|
||||||
@@ -266,6 +299,7 @@ class OfflineStream::Impl {
|
|||||||
knf::MfccOptions mfcc_opts_;
|
knf::MfccOptions mfcc_opts_;
|
||||||
OfflineRecognitionResult r_;
|
OfflineRecognitionResult r_;
|
||||||
ContextGraphPtr context_graph_;
|
ContextGraphPtr context_graph_;
|
||||||
|
bool is_ced_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
|
OfflineStream::OfflineStream(const FeatureExtractorConfig &config /*= {}*/,
|
||||||
|
|||||||
Reference in New Issue
Block a user