Refactor feature extractor (#26)
This commit is contained in:
@@ -6,52 +6,50 @@
|
|||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <mutex> // NOLINT
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
FeatureExtractor::FeatureExtractor() {
|
class FeatureExtractor::Impl {
|
||||||
|
public:
|
||||||
|
Impl(int32_t sampling_rate, int32_t feature_dim) {
|
||||||
opts_.frame_opts.dither = 0;
|
opts_.frame_opts.dither = 0;
|
||||||
opts_.frame_opts.snip_edges = false;
|
opts_.frame_opts.snip_edges = false;
|
||||||
opts_.frame_opts.samp_freq = 16000;
|
opts_.frame_opts.samp_freq = sampling_rate;
|
||||||
|
|
||||||
// cache 100 seconds of feature frames, which is more than enough
|
// cache 100 seconds of feature frames, which is more than enough
|
||||||
// for real needs
|
// for real needs
|
||||||
opts_.frame_opts.max_feature_vectors = 100 * 100;
|
opts_.frame_opts.max_feature_vectors = 100 * 100;
|
||||||
|
|
||||||
opts_.mel_opts.num_bins = 80; // feature dim
|
opts_.mel_opts.num_bins = feature_dim;
|
||||||
|
|
||||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||||
}
|
}
|
||||||
|
|
||||||
FeatureExtractor::FeatureExtractor(const knf::FbankOptions &opts)
|
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
|
||||||
: opts_(opts) {
|
|
||||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
|
||||||
}
|
|
||||||
|
|
||||||
void FeatureExtractor::AcceptWaveform(float sampling_rate,
|
|
||||||
const float *waveform, int32_t n) {
|
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
void FeatureExtractor::InputFinished() {
|
void InputFinished() {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
fbank_->InputFinished();
|
fbank_->InputFinished();
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t FeatureExtractor::NumFramesReady() const {
|
int32_t NumFramesReady() const {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
return fbank_->NumFramesReady();
|
return fbank_->NumFramesReady();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool FeatureExtractor::IsLastFrame(int32_t frame) const {
|
bool IsLastFrame(int32_t frame) const {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
return fbank_->IsLastFrame(frame);
|
return fbank_->IsLastFrame(frame);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
|
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const {
|
||||||
int32_t n) const {
|
|
||||||
if (frame_index + n > NumFramesReady()) {
|
if (frame_index + n > NumFramesReady()) {
|
||||||
fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
|
fprintf(stderr, "%d + %d > %d\n", frame_index, n, NumFramesReady());
|
||||||
exit(-1);
|
exit(-1);
|
||||||
@@ -70,10 +68,46 @@ std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
|
|||||||
}
|
}
|
||||||
|
|
||||||
return features;
|
return features;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Reset() { fbank_ = std::make_unique<knf::OnlineFbank>(opts_); }
|
||||||
|
|
||||||
|
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||||
|
knf::FbankOptions opts_;
|
||||||
|
mutable std::mutex mutex_;
|
||||||
|
};
|
||||||
|
|
||||||
|
FeatureExtractor::FeatureExtractor(int32_t sampling_rate /*=16000*/,
|
||||||
|
int32_t feature_dim /*=80*/)
|
||||||
|
: impl_(std::make_unique<Impl>(sampling_rate, feature_dim)) {}
|
||||||
|
|
||||||
|
FeatureExtractor::~FeatureExtractor() = default;
|
||||||
|
|
||||||
|
void FeatureExtractor::AcceptWaveform(float sampling_rate,
|
||||||
|
const float *waveform, int32_t n) {
|
||||||
|
impl_->AcceptWaveform(sampling_rate, waveform, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
void FeatureExtractor::Reset() {
|
void FeatureExtractor::InputFinished() { impl_->InputFinished(); }
|
||||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
|
||||||
|
int32_t FeatureExtractor::NumFramesReady() const {
|
||||||
|
return impl_->NumFramesReady();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool FeatureExtractor::IsLastFrame(int32_t frame) const {
|
||||||
|
return impl_->IsLastFrame(frame);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> FeatureExtractor::GetFrames(int32_t frame_index,
|
||||||
|
int32_t n) const {
|
||||||
|
return impl_->GetFrames(frame_index, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
void FeatureExtractor::Reset() { impl_->Reset(); }
|
||||||
|
|
||||||
|
int32_t FeatureExtractor::FeatureDim() const { return impl_->FeatureDim(); }
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -6,17 +6,19 @@
|
|||||||
#define SHERPA_ONNX_CSRC_FEATURES_H_
|
#define SHERPA_ONNX_CSRC_FEATURES_H_
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <mutex> // NOLINT
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
class FeatureExtractor {
|
class FeatureExtractor {
|
||||||
public:
|
public:
|
||||||
FeatureExtractor();
|
/**
|
||||||
explicit FeatureExtractor(const knf::FbankOptions &fbank_opts);
|
* @param sampling_rate Sampling rate of the data used to train the model.
|
||||||
|
* @param feature_dim Dimension of the features used to train the model.
|
||||||
|
*/
|
||||||
|
explicit FeatureExtractor(int32_t sampling_rate = 16000,
|
||||||
|
int32_t feature_dim = 80);
|
||||||
|
~FeatureExtractor();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@param sampling_rate The sampling_rate of the input waveform. Should match
|
@param sampling_rate The sampling_rate of the input waveform. Should match
|
||||||
@@ -48,12 +50,13 @@ class FeatureExtractor {
|
|||||||
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;
|
std::vector<float> GetFrames(int32_t frame_index, int32_t n) const;
|
||||||
|
|
||||||
void Reset();
|
void Reset();
|
||||||
int32_t FeatureDim() const { return opts_.mel_opts.num_bins; }
|
|
||||||
|
/// Return feature dim of this extractor
|
||||||
|
int32_t FeatureDim() const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
class Impl;
|
||||||
knf::FbankOptions opts_;
|
std::unique_ptr<Impl> impl_;
|
||||||
mutable std::mutex mutex_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace sherpa_onnx
|
} // namespace sherpa_onnx
|
||||||
|
|||||||
@@ -2,8 +2,9 @@
|
|||||||
//
|
//
|
||||||
// Copyright (c) 2022-2023 Xiaomi Corporation
|
// Copyright (c) 2022-2023 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <chrono> // NOLINT
|
#include <chrono> // NOLINT
|
||||||
#include <iostream>
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@@ -30,14 +31,14 @@ Please refer to
|
|||||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
|
||||||
for a list of pre-trained models to download.
|
for a list of pre-trained models to download.
|
||||||
)usage";
|
)usage";
|
||||||
std::cerr << usage << "\n";
|
fprintf(stderr, "%s\n", usage);
|
||||||
|
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string tokens = argv[1];
|
std::string tokens = argv[1];
|
||||||
sherpa_onnx::OnlineTransducerModelConfig config;
|
sherpa_onnx::OnlineTransducerModelConfig config;
|
||||||
config.debug = true;
|
config.debug = false;
|
||||||
config.encoder_filename = argv[2];
|
config.encoder_filename = argv[2];
|
||||||
config.decoder_filename = argv[3];
|
config.decoder_filename = argv[3];
|
||||||
config.joiner_filename = argv[4];
|
config.joiner_filename = argv[4];
|
||||||
@@ -47,7 +48,7 @@ for a list of pre-trained models to download.
|
|||||||
if (argc == 7) {
|
if (argc == 7) {
|
||||||
config.num_threads = atoi(argv[6]);
|
config.num_threads = atoi(argv[6]);
|
||||||
}
|
}
|
||||||
std::cout << config.ToString().c_str() << "\n";
|
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||||
|
|
||||||
auto model = sherpa_onnx::OnlineTransducerModel::Create(config);
|
auto model = sherpa_onnx::OnlineTransducerModel::Create(config);
|
||||||
|
|
||||||
@@ -72,17 +73,17 @@ for a list of pre-trained models to download.
|
|||||||
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok);
|
sherpa_onnx::ReadWave(wav_filename, expected_sampling_rate, &is_ok);
|
||||||
|
|
||||||
if (!is_ok) {
|
if (!is_ok) {
|
||||||
std::cerr << "Failed to read " << wav_filename << "\n";
|
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
|
||||||
return -1;
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
const float duration = samples.size() / expected_sampling_rate;
|
float duration = samples.size() / static_cast<float>(expected_sampling_rate);
|
||||||
|
|
||||||
std::cout << "wav filename: " << wav_filename << "\n";
|
fprintf(stderr, "wav filename: %s\n", wav_filename.c_str());
|
||||||
std::cout << "wav duration (s): " << duration << "\n";
|
fprintf(stderr, "wav duration (s): %.3f\n", duration);
|
||||||
|
|
||||||
auto begin = std::chrono::steady_clock::now();
|
auto begin = std::chrono::steady_clock::now();
|
||||||
std::cout << "Started!\n";
|
fprintf(stderr, "Started\n");
|
||||||
|
|
||||||
sherpa_onnx::FeatureExtractor feat_extractor;
|
sherpa_onnx::FeatureExtractor feat_extractor;
|
||||||
feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(),
|
feat_extractor.AcceptWaveform(expected_sampling_rate, samples.data(),
|
||||||
@@ -115,10 +116,10 @@ for a list of pre-trained models to download.
|
|||||||
text += sym[hyp[i]];
|
text += sym[hyp[i]];
|
||||||
}
|
}
|
||||||
|
|
||||||
std::cout << "Done!\n";
|
fprintf(stderr, "Done!\n");
|
||||||
|
|
||||||
std::cout << "Recognition result for " << wav_filename << "\n"
|
fprintf(stderr, "Recognition result for %s:\n%s\n", wav_filename.c_str(),
|
||||||
<< text << "\n";
|
text.c_str());
|
||||||
|
|
||||||
auto end = std::chrono::steady_clock::now();
|
auto end = std::chrono::steady_clock::now();
|
||||||
float elapsed_seconds =
|
float elapsed_seconds =
|
||||||
@@ -126,7 +127,7 @@ for a list of pre-trained models to download.
|
|||||||
.count() /
|
.count() /
|
||||||
1000.;
|
1000.;
|
||||||
|
|
||||||
std::cout << "num threads: " << config.num_threads << "\n";
|
fprintf(stderr, "num threads: %d\n", config.num_threads);
|
||||||
|
|
||||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||||
float rtf = elapsed_seconds / duration;
|
float rtf = elapsed_seconds / duration;
|
||||||
|
|||||||
Reference in New Issue
Block a user