Refactor feature extractor (#26)

This commit is contained in:
Fangjun Kuang
2023-02-19 09:57:56 +08:00
committed by GitHub
parent cb8f85ff83
commit 710edaa6f9
3 changed files with 105 additions and 67 deletions

View File

@@ -6,74 +6,108 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <mutex> // NOLINT
#include <vector> #include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h"
namespace sherpa_onnx { namespace sherpa_onnx {
FeatureExtractor::FeatureExtractor() { class FeatureExtractor::Impl {
opts_.frame_opts.dither = 0; public:
opts_.frame_opts.snip_edges = false; Impl(int32_t sampling_rate, int32_t feature_dim) {
opts_.frame_opts.samp_freq = 16000; opts_.frame_opts.dither = 0;
opts_.frame_opts.snip_edges = false;
opts_.frame_opts.samp_freq = sampling_rate;
// cache 100 seconds of feature frames, which is more than enough // cache 100 seconds of feature frames, which is more than enough
// for real needs // for real needs
opts_.frame_opts.max_feature_vectors = 100 * 100; opts_.frame_opts.max_feature_vectors = 100 * 100;
opts_.mel_opts.num_bins = 80; // feature dim opts_.mel_opts.num_bins = feature_dim;
fbank_ = std::make_unique<knf::OnlineFbank>(opts_); fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
} }
FeatureExtractor::FeatureExtractor(const knf::FbankOptions &opts) void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
: opts_(opts) { std::lock_guard<std::mutex> lock(mutex_);
fbank_ = std::make_unique<knf::OnlineFbank>(opts_); fbank_->AcceptWaveform(sampling_rate, waveform, n);
} }
/// Forward the end-of-input signal to the underlying fbank computer.
/// The call is made while holding mutex_.
void InputFinished() {
  std::lock_guard<std::mutex> guard(mutex_);

  fbank_->InputFinished();
}
/// Return the value reported by fbank_->NumFramesReady(), queried while
/// holding mutex_.
int32_t NumFramesReady() const {
  std::lock_guard<std::mutex> guard(mutex_);

  const int32_t num_ready = fbank_->NumFramesReady();
  return num_ready;
}
/// Return the value reported by fbank_->IsLastFrame(frame), queried while
/// holding mutex_.
///
/// @param frame Frame index passed through unchanged to the fbank computer.
bool IsLastFrame(int32_t frame) const {
  std::lock_guard<std::mutex> guard(mutex_);

  const bool is_last = fbank_->IsLastFrame(frame);
  return is_last;
}
/**
 * Copy `n` consecutive feature frames, starting at `frame_index`, into one
 * contiguous buffer.
 *
 * @param frame_index Index of the first frame to copy. Must be >= 0.
 * @param n           Number of frames to copy. Must be >= 0, and
 *                    frame_index + n must not exceed the number of frames
 *                    reported ready by the fbank computer.
 * @return A vector holding fbank_->Dim() * n floats, frames stored back to
 *         back in index order.
 *
 * If the requested range is invalid, a message is written to stderr and the
 * process is terminated with exit(-1).
 */
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
  // Acquire the lock once and keep it for both the range check and the copy,
  // so the frame count that was validated is the one the copy runs against.
  std::unique_lock<std::mutex> lock(mutex_);

  const int32_t num_ready = fbank_->NumFramesReady();

  if (frame_index < 0 || n < 0) {
    // exit() does not unwind the stack, so release the lock explicitly
    // rather than leaving mutex_ locked during process shutdown.
    lock.unlock();
    fprintf(stderr, "Invalid frame range: frame_index=%d, n=%d\n", frame_index,
            n);
    exit(-1);
  }

  // Expressed as a subtraction so that frame_index + n cannot overflow.
  if (n > num_ready - frame_index) {
    lock.unlock();
    fprintf(stderr, "%d + %d > %d\n", frame_index, n, num_ready);
    exit(-1);
  }

  const int32_t feature_dim = fbank_->Dim();

  // Widen before multiplying so the element count is computed in size_t.
  std::vector<float> features(static_cast<size_t>(feature_dim) * n);

  float *p = features.data();
  for (int32_t i = 0; i != n; ++i) {
    const float *f = fbank_->GetFrame(i + frame_index);
    // std::copy returns the position one past the last element written.
    p = std::copy(f, f + feature_dim, p);
  }

  return features;
}
void Reset() { fbank_ = std::make_unique<knf::OnlineFbank>(opts_); }
/// Return the number of mel bins stored in opts_ (set from the constructor's
/// feature_dim argument).
int32_t FeatureDim() const {
  const int32_t num_bins = opts_.mel_opts.num_bins;
  return num_bins;
}
private:
std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions opts_;
mutable std::mutex mutex_;
};
FeatureExtractor::FeatureExtractor(int32_t sampling_rate /*=16000*/,
int32_t feature_dim /*=80*/)
: impl_(std::make_unique<Impl>(sampling_rate, feature_dim)) {}
// Defined here rather than in the header: impl_ is a std::unique_ptr<Impl>,
// and destroying it requires Impl to be a complete type, which is only the
// case in this translation unit.
FeatureExtractor::~FeatureExtractor() = default;
void FeatureExtractor::AcceptWaveform(float sampling_rate, void FeatureExtractor::AcceptWaveform(float sampling_rate,
const float *waveform, int32_t n) { const float *waveform, int32_t n) {
std::lock_guard<std::mutex> lock(mutex_); impl_->AcceptWaveform(sampling_rate, waveform, n);
fbank_->AcceptWaveform(sampling_rate, waveform, n);
} }
void FeatureExtractor::InputFinished() { void FeatureExtractor::InputFinished() { impl_->InputFinished(); }
std::lock_guard<std::mutex> lock(mutex_);
fbank_->InputFinished();
}
int32_t FeatureExtractor::NumFramesReady() const { int32_t FeatureExtractor::NumFramesReady() const {
std::lock_guard<std::mutex> lock(mutex_); return impl_->NumFramesReady();
return fbank_->NumFramesReady();
} }
bool FeatureExtractor::IsLastFrame(int32_t frame) const { bool FeatureExtractor::IsLastFrame(int32_t frame) const {
std::lock_guard<std::mutex> lock(mutex_); return impl_->IsLastFrame(frame);
return fbank_->IsLastFrame(frame);
} }
std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index, std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
int32_t n) const { int32_t n) const {
if (frame_index + n > NumFramesReady()) { return impl_->GetFrames(frame_index, n);
fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
exit(-1);
}
std::lock_guard<std::mutex> lock(mutex_);
int32_t feature_dim = fbank_->Dim();
std::vector<float> features(feature_dim * n);
float *p = features.data();
for (int32_t i = 0; i != n; ++i) {
const float *f = fbank_->GetFrame(i + frame_index);
std::copy(f, f + feature_dim, p);
p += feature_dim;
}
return features;
} }
void FeatureExtractor::Reset() { void FeatureExtractor::Reset() { impl_->Reset(); }
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
} int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); }
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -6,17 +6,19 @@
#define SHERPA_ONNX_CSRC_FEATURES_H_ #define SHERPA_ONNX_CSRC_FEATURES_H_
#include <memory> #include <memory>
#include <mutex> // NOLINT
#include <vector> #include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h"
namespace sherpa_onnx { namespace sherpa_onnx {
class FeatureExtractor { class FeatureExtractor {
public: public:
FeatureExtractor(); /**
explicit FeatureExtractor(const knf::FbankOptions &fbank_opts); * @param sampling_rate Sampling rate of the data used to train the model.
* @param feature_dim Dimension of the features used to train the model.
*/
explicit FeatureExtractor(int32_t sampling_rate = 16000,
int32_t feature_dim = 80);
~FeatureExtractor();
/** /**
@param sampling_rate The sampling_rate of the input waveform. Should match @param sampling_rate The sampling_rate of the input waveform. Should match
@@ -48,12 +50,13 @@ class FeatureExtractor {
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const; std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;
void Reset(); void Reset();
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
/// Return feature dim of this extractor
int32_t FeatureDim() const;
private: private:
std::unique_ptr<knf::OnlineFbank> fbank_; class Impl;
knf::FbankOptions opts_; std::unique_ptr<Impl> impl_;
mutable std::mutex mutex_;
}; };
} // namespace sherpa_onnx } // namespace sherpa_onnx

View File

@@ -2,8 +2,9 @@
// //
// Copyright (c) 2022-2023 Xiaomi Corporation // Copyright (c) 2022-2023 Xiaomi Corporation
#include <stdio.h>
#include <chrono> // NOLINT #include <chrono> // NOLINT
#include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
@@ -30,14 +31,14 @@ Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download. for a list of pre-trained models to download.
)usage"; )usage";
std::cerr << usage << "\n"; fprintf(stderr, "%s\n", usage);
return 0; return 0;
} }
std::string tokens = argv[1]; std::string tokens = argv[1];
sherpa_onnx::OnlineTransducerModelConfig config; sherpa_onnx::OnlineTransducerModelConfig config;
config.debug = true; config.debug = false;
config.encoder_filename = argv[2]; config.encoder_filename = argv[2];
config.decoder_filename = argv[3]; config.decoder_filename = argv[3];
config.joiner_filename = argv[4]; config.joiner_filename = argv[4];
@@ -47,7 +48,7 @@ for a list of pre-trained models to download.
if (argc == 7) { if (argc == 7) {
config.num_threads = atoi(argv[6]); config.num_threads = atoi(argv[6]);
} }
std::cout << config.ToString().c_str() << "\n"; fprintf(stderr, "%s\n", config.ToString().c_str());
auto model = sherpa_onnx::OnlineTransducerModel::Create(config); auto model = sherpa_onnx::OnlineTransducerModel::Create(config);
@@ -72,17 +73,17 @@ for a list of pre-trained models to download.
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok); sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok);
if (!is_ok) { if (!is_ok) {
std::cerr << "Failed to read " << wav_filename << "\n"; fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
return -1; return -1;
} }
const float duration = samples.size() / expected_sampling_rate; float duration = samples.size() / static_cast<float>(expected_sampling_rate);
std::cout << "wav filename: " << wav_filename << "\n"; fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
std::cout << "wav duration (s): " << duration << "\n"; fprintf(stderr, "wav duration (s): %.3f\n", duration);
auto begin = std::chrono::steady_clock::now(); auto begin = std::chrono::steady_clock::now();
std::cout << "Started!\n"; fprintf(stderr, "Started\n");
sherpa_onnx::FeatureExtractor feat_extractor; sherpa_onnx::FeatureExtractor feat_extractor;
feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(), feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(),
@@ -115,10 +116,10 @@ for a list of pre-trained models to download.
text += sym[hyp[i]]; text += sym[hyp[i]];
} }
std::cout << "Done!\n"; fprintf(stderr, "Done!\n");
std::cout << "Recognition result for " << wav_filename << "\n" fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
<< text << "\n"; text.c_str());
auto end = std::chrono::steady_clock::now(); auto end = std::chrono::steady_clock::now();
float elapsed_seconds = float elapsed_seconds =
@@ -126,7 +127,7 @@ for a list of pre-trained models to download.
.count() / .count() /
1000.; 1000.;
std::cout << "num threads: " << config.num_threads << "\n"; fprintf(stderr, "num threads: %d\n", config.num_threads);
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration; float rtf = elapsed_seconds / duration;