diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index 7bb47850..2dab27fd 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -6,74 +6,108 @@ #include #include +#include // NOLINT #include +#include "kaldi-native-fbank/csrc/online-feature.h" + namespace sherpa_onnx { -FeatureExtractor::FeatureExtractor() { - opts_.frame_opts.dither = 0; - opts_.frame_opts.snip_edges = false; - opts_.frame_opts.samp_freq = 16000; +class FeatureExtractor::Impl { + public: + Impl(int32_t sampling_rate, int32_t feature_dim) { + opts_.frame_opts.dither = 0; + opts_.frame_opts.snip_edges = false; + opts_.frame_opts.samp_freq = sampling_rate; - // cache 100 seconds of feature frames, which is more than enough - // for real needs - opts_.frame_opts.max_feature_vectors = 100 * 100; + // cache 100 seconds of feature frames, which is more than enough + // for real needs + opts_.frame_opts.max_feature_vectors = 100 * 100; - opts_.mel_opts.num_bins = 80; // feature dim + opts_.mel_opts.num_bins = feature_dim; - fbank_ = std::make_unique(opts_); -} + fbank_ = std::make_unique(opts_); + } -FeatureExtractor::FeatureExtractor(const knf::FbankOptions &opts) - : opts_(opts) { - fbank_ = std::make_unique(opts_); -} + void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) { + std::lock_guard lock(mutex_); + fbank_->AcceptWaveform(sampling_rate, waveform, n); + } + + void InputFinished() { + std::lock_guard lock(mutex_); + fbank_->InputFinished(); + } + + int32_t NumFramesReady() const { + std::lock_guard lock(mutex_); + return fbank_->NumFramesReady(); + } + + bool IsLastFrame(int32_t frame) const { + std::lock_guard lock(mutex_); + return fbank_->IsLastFrame(frame); + } + + std::vector GetFrames(int32_t frame_index, int32_t n) const { + if (frame_index + n > NumFramesReady()) { + fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady()); + exit(-1); + } + std::lock_guard lock(mutex_); + + int32_t feature_dim = fbank_->Dim(); + std::vector features(feature_dim * n); + + float *p = features.data(); + + for (int32_t i = 0; i != n; ++i) { + const float *f = fbank_->GetFrame(i + frame_index); + std::copy(f, f + feature_dim, p); + p += feature_dim; + } + + return features; + } + + void Reset() { fbank_ = std::make_unique(opts_); } + + int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } + + private: + std::unique_ptr fbank_; + knf::FbankOptions opts_; + mutable std::mutex mutex_; +}; + +FeatureExtractor::FeatureExtractor(int32_t sampling_rate /*=16000*/, + int32_t feature_dim /*=80*/) + : impl_(std::make_unique(sampling_rate, feature_dim)) {} + +FeatureExtractor::~FeatureExtractor() = default; void FeatureExtractor::AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) { - std::lock_guard lock(mutex_); - fbank_->AcceptWaveform(sampling_rate, waveform, n); + impl_->AcceptWaveform(sampling_rate, waveform, n); } -void FeatureExtractor::InputFinished() { - std::lock_guard lock(mutex_); - fbank_->InputFinished(); -} +void FeatureExtractor::InputFinished() { impl_->InputFinished(); } int32_t FeatureExtractor::NumFramesReady() const { - std::lock_guard lock(mutex_); - return fbank_->NumFramesReady(); + return impl_->NumFramesReady(); } bool FeatureExtractor::IsLastFrame(int32_t frame) const { - std::lock_guard lock(mutex_); - return fbank_->IsLastFrame(frame); + return impl_->IsLastFrame(frame); } std::vector FeatureExtractor::GetFrames(int32_t frame_index, int32_t n) const { - if (frame_index + n > NumFramesReady()) { - fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady()); - exit(-1); - } - std::lock_guard lock(mutex_); - - int32_t feature_dim = fbank_->Dim(); - std::vector features(feature_dim * n); - - float *p = features.data(); - - for (int32_t i = 0; i != n; ++i) { - const float *f = fbank_->GetFrame(i + frame_index); - std::copy(f, f + feature_dim, p); - p += feature_dim; - } - - return features; + return impl_->GetFrames(frame_index, n); } -void FeatureExtractor::Reset() { - fbank_ = std::make_unique(opts_); -} +void FeatureExtractor::Reset() { impl_->Reset(); } + +int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index 8e569c2b..9ff3104b 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -6,17 +6,19 @@ #define SHERPA_ONNX_CSRC_FEATURES_H_ #include -#include // NOLINT #include -#include "kaldi-native-fbank/csrc/online-feature.h" - namespace sherpa_onnx { class FeatureExtractor { public: - FeatureExtractor(); - explicit FeatureExtractor(const knf::FbankOptions &fbank_opts); + /** + * @param sampling_rate Sampling rate of the data used to train the model. + * @param feature_dim Dimension of the features used to train the model. + */ + explicit FeatureExtractor(int32_t sampling_rate = 16000, + int32_t feature_dim = 80); + ~FeatureExtractor(); /** @param sampling_rate The sampling_rate of the input waveform. Should match @@ -48,12 +50,13 @@ class FeatureExtractor { std::vector GetFrames(int32_t frame_index, int32_t n) const; void Reset(); - int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } + + /// Return feature dim of this extractor + int32_t FeatureDim() const; private: - std::unique_ptr fbank_; - knf::FbankOptions opts_; - mutable std::mutex mutex_; + class Impl; + std::unique_ptr impl_; }; } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 02d984e2..1e18a9be 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -2,8 +2,9 @@ // // Copyright (c) 2022-2023 Xiaomi Corporation +#include + #include // NOLINT -#include #include #include @@ -30,14 +31,14 @@ Please refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html for a list of pre-trained models to download. )usage"; - std::cerr << usage << "\n"; + fprintf(stderr, "%s\n", usage); return 0; } std::string tokens = argv[1]; sherpa_onnx::OnlineTransducerModelConfig config; - config.debug = true; + config.debug = false; config.encoder_filename = argv[2]; config.decoder_filename = argv[3]; config.joiner_filename = argv[4]; @@ -47,7 +48,7 @@ for a list of pre-trained models to download. if (argc == 7) { config.num_threads = atoi(argv[6]); } - std::cout << config.ToString().c_str() << "\n"; + fprintf(stderr, "%s\n", config.ToString().c_str()); auto model = sherpa_onnx::OnlineTransducerModel::Create(config); @@ -72,17 +73,17 @@ for a list of pre-trained models to download. sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok); if (!is_ok) { - std::cerr << "Failed to read " << wav_filename << "\n"; + fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); return -1; } - const float duration = samples.size() / expected_sampling_rate; + float duration = samples.size() / static_cast(expected_sampling_rate); - std::cout << "wav filename: " << wav_filename << "\n"; - std::cout << "wav duration (s): " << duration << "\n"; + fprintf(stderr, "wav filename: %s\n", wav_filename.c_str()); + fprintf(stderr, "wav duration (s): %.3f\n", duration); auto begin = std::chrono::steady_clock::now(); - std::cout << "Started!\n"; + fprintf(stderr, "Started\n"); sherpa_onnx::FeatureExtractor feat_extractor; feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(), @@ -115,10 +116,10 @@ for a list of pre-trained models to download. text += sym[hyp[i]]; } - std::cout << "Done!\n"; + fprintf(stderr, "Done!\n"); - std::cout << "Recognition result for " << wav_filename << "\n" - << text << "\n"; + fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(), + text.c_str()); auto end = std::chrono::steady_clock::now(); float elapsed_seconds = @@ -126,7 +127,7 @@ for a list of pre-trained models to download. .count() / 1000.; - std::cout << "num threads: " << config.num_threads << "\n"; + fprintf(stderr, "num threads: %d\n", config.num_threads); fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); float rtf = elapsed_seconds / duration;