Configurable low_freq high_freq, dithering (#664)
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
function(download_kaldi_native_fbank)
|
||||
include(FetchContent)
|
||||
|
||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.7.tar.gz")
|
||||
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.7.tar.gz")
|
||||
set(kaldi_native_fbank_HASH "SHA256=e78fd9d481d83d7d6d1be0012752e6531cb614e030558a3491e3c033cb8e0e4e")
|
||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.19.1.tar.gz")
|
||||
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.19.1.tar.gz")
|
||||
set(kaldi_native_fbank_HASH "SHA256=0cae8cbb9ea42916b214e088912f9e8f2f648f54756b305f93f552382f31f904")
|
||||
|
||||
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||
|
||||
@@ -25,6 +25,19 @@ void FeatureExtractorConfig::Register(ParseOptions *po) {
|
||||
|
||||
po->Register("feat-dim", &feature_dim,
|
||||
"Feature dimension. Must match the one expected by the model.");
|
||||
|
||||
po->Register("low-freq", &low_freq,
|
||||
"Low cutoff frequency for mel bins");
|
||||
|
||||
po->Register("high-freq", &high_freq,
|
||||
"High cutoff frequency for mel bins "
|
||||
"(if <= 0, offset from Nyquist)");
|
||||
|
||||
po->Register("dither", &dither,
|
||||
"Dithering constant (0.0 means no dither). "
|
||||
"By default the audio samples are in range [-1,+1], "
|
||||
"so 0.00003 is a good value, "
|
||||
"equivalent to the default 1.0 from kaldi");
|
||||
}
|
||||
|
||||
std::string FeatureExtractorConfig::ToString() const {
|
||||
@@ -32,7 +45,10 @@ std::string FeatureExtractorConfig::ToString() const {
|
||||
|
||||
os << "FeatureExtractorConfig(";
|
||||
os << "sampling_rate=" << sampling_rate << ", ";
|
||||
os << "feature_dim=" << feature_dim << ")";
|
||||
os << "feature_dim=" << feature_dim << ", ";
|
||||
os << "low_freq=" << low_freq << ", ";
|
||||
os << "high_freq=" << high_freq << ", ";
|
||||
os << "dither=" << dither << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
@@ -40,7 +56,7 @@ std::string FeatureExtractorConfig::ToString() const {
|
||||
class FeatureExtractor::Impl {
|
||||
public:
|
||||
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
|
||||
opts_.frame_opts.dither = 0;
|
||||
opts_.frame_opts.dither = config.dither;
|
||||
opts_.frame_opts.snip_edges = config.snip_edges;
|
||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
|
||||
@@ -50,13 +66,9 @@ class FeatureExtractor::Impl {
|
||||
|
||||
opts_.mel_opts.num_bins = config.feature_dim;
|
||||
|
||||
// Please see
|
||||
// https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27
|
||||
// and
|
||||
// https://github.com/k2-fsa/sherpa-onnx/issues/514
|
||||
opts_.mel_opts.high_freq = -400;
|
||||
opts_.mel_opts.high_freq = config.high_freq;
|
||||
opts_.mel_opts.low_freq = config.low_freq;
|
||||
|
||||
opts_.mel_opts.low_freq = config.low_freq;
|
||||
opts_.mel_opts.is_librosa = config.is_librosa;
|
||||
|
||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||
|
||||
@@ -21,6 +21,27 @@ struct FeatureExtractorConfig {
|
||||
// Feature dimension
|
||||
int32_t feature_dim = 80;
|
||||
|
||||
// minimal frequency for Mel-filterbank, in Hz
|
||||
float low_freq = 20.0f;
|
||||
|
||||
// maximal frequency of Mel-filterbank
|
||||
// in Hz; negative value is subtracted from Nyquist freq.:
|
||||
// i.e. for sampling_rate 16000 / 2 - 400 = 7600Hz
|
||||
//
|
||||
// Please see
|
||||
// https://github.com/lhotse-speech/lhotse/blob/master/lhotse/features/fbank.py#L27
|
||||
// and
|
||||
// https://github.com/k2-fsa/sherpa-onnx/issues/514
|
||||
float high_freq = -400.0f;
|
||||
|
||||
// dithering constant, useful for signals with hard-zeroes in non-speech parts
|
||||
// this prevents large negative values in log-mel filterbanks
|
||||
//
|
||||
// In k2, audio samples are in range [-1..+1], in kaldi the range was
|
||||
// [-32k..+32k], so the value 0.00003 is equivalent to kaldi default 1.0
|
||||
//
|
||||
float dither = 0.0f; // dithering disabled by default
|
||||
|
||||
// Set internally by some models, e.g., paraformer sets it to false.
|
||||
// This parameter is not exposed to users from the commandline
|
||||
// If true, the feature extractor expects inputs to be normalized to
|
||||
@@ -31,7 +52,6 @@ struct FeatureExtractorConfig {
|
||||
bool snip_edges = false;
|
||||
float frame_shift_ms = 10.0f; // in milliseconds.
|
||||
float frame_length_ms = 25.0f; // in milliseconds.
|
||||
int32_t low_freq = 20;
|
||||
bool is_librosa = false;
|
||||
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
|
||||
std::string window_type = "povey"; // e.g. Hamming window
|
||||
|
||||
@@ -72,6 +72,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
model_->SetFeatureDim(config.feat_config.feature_dim);
|
||||
|
||||
InitKeywords();
|
||||
|
||||
decoder_ = std::make_unique<TransducerKeywordDecoder>(
|
||||
@@ -89,6 +91,8 @@ class KeywordSpotterTransducerImpl : public KeywordSpotterImpl {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
model_->SetFeatureDim(config.feat_config.feature_dim);
|
||||
|
||||
InitKeywords(mgr);
|
||||
|
||||
decoder_ = std::make_unique<TransducerKeywordDecoder>(
|
||||
|
||||
@@ -90,6 +90,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
model_->SetFeatureDim(config.feat_config.feature_dim);
|
||||
|
||||
if (config.decoding_method == "modified_beam_search") {
|
||||
if (!config_.hotwords_file.empty()) {
|
||||
InitHotwords();
|
||||
@@ -123,6 +125,8 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
|
||||
unk_id_ = sym_["<unk>"];
|
||||
}
|
||||
|
||||
model_->SetFeatureDim(config.feat_config.feature_dim);
|
||||
|
||||
if (config.decoding_method == "modified_beam_search") {
|
||||
#if 0
|
||||
// TODO(fangjun): Implement it
|
||||
|
||||
@@ -61,6 +61,16 @@ class OnlineTransducerModel {
|
||||
*/
|
||||
virtual std::vector<Ort::Value> GetEncoderInitStates() = 0;
|
||||
|
||||
/** Set feature dim.
|
||||
*
|
||||
* This is used in `OnlineZipformer2TransducerModel`,
|
||||
* to pass `feature_dim` for `GetEncoderInitStates()`.
|
||||
*
|
||||
* This has to be called before GetEncoderInitStates(), so the `encoder_embed`
|
||||
* init state has the correct `embed_dim` of its output.
|
||||
*/
|
||||
virtual void SetFeatureDim(int32_t feature_dim) { }
|
||||
|
||||
/** Run the encoder.
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
|
||||
@@ -403,7 +403,10 @@ OnlineZipformer2TransducerModel::GetEncoderInitStates() {
|
||||
}
|
||||
|
||||
{
|
||||
std::array<int64_t, 4> s{1, 128, 3, 19};
|
||||
SHERPA_ONNX_CHECK_NE(feature_dim_, 0);
|
||||
int32_t embed_dim = (((feature_dim_ - 1) / 2) - 1) / 2;
|
||||
std::array<int64_t, 4> s{1, 128, 3, embed_dim};
|
||||
|
||||
auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
|
||||
Fill(&v, 0);
|
||||
ans.push_back(std::move(v));
|
||||
|
||||
@@ -37,6 +37,10 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
|
||||
|
||||
std::vector<Ort::Value> GetEncoderInitStates() override;
|
||||
|
||||
// Stores the feature dimension in feature_dim_, which GetEncoderInitStates()
// reads to compute the embed_dim of the encoder_embed init state.
void SetFeatureDim(int32_t feature_dim) override { feature_dim_ = feature_dim; }
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||
Ort::Value features, std::vector<Ort::Value> states,
|
||||
Ort::Value processed_frames) override;
|
||||
@@ -101,6 +105,7 @@ class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
|
||||
|
||||
int32_t context_size_ = 0;
|
||||
int32_t vocab_size_ = 0;
|
||||
int32_t feature_dim_ = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -11,10 +11,17 @@ namespace sherpa_onnx {
|
||||
// Registers FeatureExtractorConfig with the given Python module.
//
// Exposes a keyword-argument constructor, read/write attributes for
// sampling_rate, feature_dim, low_freq, high_freq and dither, and __str__
// (backed by FeatureExtractorConfig::ToString).
//
// @param m  The pybind11 module that receives the class binding.
static void PybindFeatureExtractorConfig(py::module *m) {
  using PyClass = FeatureExtractorConfig;
  py::class_<PyClass>(*m, "FeatureExtractorConfig")
      // Every argument has a default, so callers that only pass
      // sampling_rate and/or feature_dim keep working unchanged.
      .def(py::init<int32_t, int32_t, float, float, float>(),
           py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80,
           py::arg("low_freq") = 20.0f, py::arg("high_freq") = -400.0f,
           py::arg("dither") = 0.0f)
      .def_readwrite("sampling_rate", &PyClass::sampling_rate)
      .def_readwrite("feature_dim", &PyClass::feature_dim)
      .def_readwrite("low_freq", &PyClass::low_freq)
      .def_readwrite("high_freq", &PyClass::high_freq)
      // Must bind to the dither member: binding it to high_freq would make
      // the Python attribute `dither` silently read and write high_freq.
      .def_readwrite("dither", &PyClass::dither)
      .def("__str__", &PyClass::ToString);
}
|
||||
|
||||
|
||||
@@ -41,6 +41,9 @@ class OnlineRecognizer(object):
|
||||
num_threads: int = 2,
|
||||
sample_rate: float = 16000,
|
||||
feature_dim: int = 80,
|
||||
low_freq: float = 20.0,
|
||||
high_freq: float = -400.0,
|
||||
dither: float = 0.0,
|
||||
enable_endpoint_detection: bool = False,
|
||||
rule1_min_trailing_silence: float = 2.4,
|
||||
rule2_min_trailing_silence: float = 1.2,
|
||||
@@ -80,6 +83,16 @@ class OnlineRecognizer(object):
|
||||
Sample rate of the training data used to train the model.
|
||||
feature_dim:
|
||||
Dimension of the feature used to train the model.
|
||||
low_freq:
|
||||
Low cutoff frequency for mel bins in feature extraction.
|
||||
high_freq:
|
||||
High cutoff frequency for mel bins in feature extraction
|
||||
(if <= 0, offset from Nyquist)
|
||||
dither:
|
||||
Dithering constant (0.0 means no dither).
|
||||
By default the audio samples are in range [-1,+1],
|
||||
so dithering constant 0.00003 is a good value,
|
||||
equivalent to the default 1.0 from kaldi
|
||||
enable_endpoint_detection:
|
||||
True to enable endpoint detection. False to disable endpoint
|
||||
detection.
|
||||
@@ -140,6 +153,9 @@ class OnlineRecognizer(object):
|
||||
feat_config = FeatureExtractorConfig(
|
||||
sampling_rate=sample_rate,
|
||||
feature_dim=feature_dim,
|
||||
low_freq=low_freq,
|
||||
high_freq=high_freq,
|
||||
dither=dither,
|
||||
)
|
||||
|
||||
endpoint_config = EndpointConfig(
|
||||
|
||||
Reference in New Issue
Block a user