Configurable low_freq high_freq, dithering (#664)

This commit is contained in:
Karel Vesely
2024-03-22 14:41:44 +01:00
committed by GitHub
parent 2fc1201924
commit eaec4c83c2
10 changed files with 96 additions and 15 deletions

View File

@@ -25,6 +25,19 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
po->Register("feat-dim", &feature_dim,
"Feature dimension. Must match the one expected by the model.");
po->Register("low-freq", &low_freq,
"Low cutoff frequency for mel bins");
po->Register("high-freq", &high_freq,
"High cutoff frequency for mel bins "
"(if <= 0, offset from Nyquist)");
po->Register("dither", &dither,
"Dithering constant (0.0 means no dither). "
"By default the audio samples are in range [-1,+1], "
"so 0.00003 is a good value, "
"equivalent to the default 1.0 from kaldi");
}
std::string FeatureExtractorConfig::ToString() const {
@@ -32,7 +45,10 @@ std::string FeatureExtractorConfig::ToString() const {
os << "FeatureExtractorConfig(";
os << "sampling_rate=" << sampling_rate << ", ";
os << "feature_dim=" << feature_dim << ")";
os << "feature_dim=" << feature_dim << ", ";
os << "low_freq=" << low_freq << ", ";
os << "high_freq=" << high_freq << ", ";
os << "dither=" << dither << ")";
return os.str();
}
@@ -40,7 +56,7 @@ std::string FeatureExtractorConfig::ToString() const {
class FeatureExtractor::Impl {
public:
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
opts_.frame_opts.dither = 0;
opts_.frame_opts.dither = config.dither;
opts_.frame_opts.snip_edges = config.snip_edges;
opts_.frame_opts.samp_freq = config.sampling_rate;
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
@@ -50,13 +66,9 @@ class FeatureExtractor::Impl {
opts_.mel_opts.num_bins = config.feature_dim;
// Please see
// https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27
// and
// https://github.com/k2-fsa/sherpa-onnx/issues/514
opts_.mel_opts.high_freq = -400;
opts_.mel_opts.high_freq = config.high_freq;
opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.low_freq = config.low_freq;
opts_.mel_opts.is_librosa = config.is_librosa;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);

View File

@@ -21,6 +21,27 @@ struct FeatureExtractorConfig {
// Feature dimension
int32_t feature_dim = 80;
// minimal frequency for Mel-filterbank, in Hz
float low_freq = 20.0f;
// maximal frequency of Mel-filterbank
// in Hz; negative value is subtracted from Nyquist freq.:
// i.e. for sampling_rate 16000 / 2 - 400 = 7600Hz
//
// Please see
// https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27
// and
// https://github.com/k2-fsa/sherpa-onnx/issues/514
float high_freq = -400.0f;
// dithering constant, useful for signals with hard-zeroes in non-speech parts
// this prevents large negative values in log-mel filterbanks
//
// In k2, audio samples are in range [-1..+1], in kaldi the range was
// [-32k..+32k], so the value 0.00003 is equivalent to kaldi default 1.0
//
float dither = 0.0f; // dithering disabled by default
// Set internally by some models, e.g., paraformer sets it to false.
// This parameter is not exposed to users from the commandline
// If true, the feature extractor expects inputs to be normalized to
@@ -31,7 +52,6 @@ struct FeatureExtractorConfig {
bool snip_edges = false;
float frame_shift_ms = 10.0f; // in milliseconds.
float frame_length_ms = 25.0f; // in milliseconds.
int32_t low_freq = 20;
bool is_librosa = false;
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
std::string window_type = "povey"; // e.g. Hamming window

View File

@@ -72,6 +72,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
unk_id_ = sym_["<unk>"];
}
model_->SetFeatureDim(config.feat_config.feature_dim);
InitKeywords();
decoder_ = std::make_unique<TransducerKeywordDecoder>(
@@ -89,6 +91,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
unk_id_ = sym_["<unk>"];
}
model_->SetFeatureDim(config.feat_config.feature_dim);
InitKeywords(mgr);
decoder_ = std::make_unique<TransducerKeywordDecoder>(

View File

@@ -90,6 +90,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
unk_id_ = sym_["<unk>"];
}
model_->SetFeatureDim(config.feat_config.feature_dim);
if (config.decoding_method == "modified_beam_search") {
if (!config_.hotwords_file.empty()) {
InitHotwords();
@@ -123,6 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
unk_id_ = sym_["<unk>"];
}
model_->SetFeatureDim(config.feat_config.feature_dim);
if (config.decoding_method == "modified_beam_search") {
#if 0
// TODO(fangjun): Implement it

View File

@@ -61,6 +61,16 @@ class OnlineTransducerModel {
*/
virtual std::vector<Ort::Value> GetEncoderInitStates() = 0;
/** Set feature dim.
*
* This is used in `OnlineZipformer2TransducerModel`,
* to pass `feature_dim` for `GetEncoderInitStates()`.
*
* This has to be called before GetEncoderInitStates(), so the `encoder_embed`
* init state has the correct `embed_dim` of its output.
*/
virtual void SetFeatureDim(int32_t feature_dim) { }
/** Run the encoder.
*
* @param features A tensor of shape (N, T, C). It is changed in-place.

View File

@@ -403,7 +403,10 @@ OnlineZipformer2TransducerModel::GetEncoderInitStates() {
}
{
std::array<int64_t, 4> s{1, 128, 3, 19};
SHERPA_ONNX_CHECK_NE(feature_dim_, 0);
int32_t embed_dim = (((feature_dim_ - 1) / 2) - 1) / 2;
std::array<int64_t, 4> s{1, 128, 3, embed_dim};
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
Fill(&v, 0);
ans.push_back(std::move(v));

View File

@@ -37,6 +37,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
std::vector<Ort::Value> GetEncoderInitStates() override;
void SetFeatureDim(int32_t feature_dim) override {
feature_dim_ = feature_dim;
}
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
Ort::Value features, std::vector<Ort::Value> states,
Ort::Value processed_frames) override;
@@ -101,6 +105,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
int32_t context_size_ = 0;
int32_t vocab_size_ = 0;
int32_t feature_dim_ = 0;
};
} // namespace sherpa_onnx