Support resampling (#77)

This commit is contained in:
Fangjun Kuang
2023-03-03 16:42:33 +08:00
committed by GitHub
parent 5f31b22c12
commit 9d8fddef01
10 changed files with 96 additions and 26 deletions

View File

@@ -11,6 +11,8 @@
#include <vector>
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/resample.h"
namespace sherpa_onnx {
@@ -50,6 +52,46 @@ class FeatureExtractor::Impl {
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
std::lock_guard<std::mutex> lock(mutex_);
if (resampler_) {
if (sampling_rate != resampler_->GetInputSamplingRate()) {
SHERPA_ONNX_LOGE(
"You changed the input sampling rate!! Expected: %d, given: "
"%d",
resampler_->GetInputSamplingRate(), sampling_rate);
exit(-1);
}
std::vector<float> samples;
resampler_->Resample(waveform, n, false, &samples);
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
return;
}
if (sampling_rate != opts_.frame_opts.samp_freq) {
SHERPA_ONNX_LOGE(
"Creating a resampler:\n"
" in_sample_rate: %d\n"
" output_sample_rate: %d\n",
sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
float min_freq =
std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
int32_t lowpass_filter_width = 6;
resampler_ = std::make_unique<LinearResample>(
sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
lowpass_filter_width);
std::vector<float> samples;
resampler_->Resample(waveform, n, false, &samples);
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
samples.size());
return;
}
fbank_->AcceptWaveform(sampling_rate, waveform, n);
}
@@ -100,6 +142,7 @@ class FeatureExtractor::Impl {
std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions opts_;
mutable std::mutex mutex_;
std::unique_ptr<LinearResample> resampler_;
};
FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)

View File

@@ -29,9 +29,11 @@ class FeatureExtractor {
~FeatureExtractor();
/**
@param sampling_rate The sampling_rate of the input waveform. Should match
the one expected by the feature extractor.
@param waveform Pointer to a 1-D array of size n
@param sampling_rate The sampling rate of the input waveform. If it is not
equal to config.sampling_rate, the waveform is resampled
internally.
@param waveform Pointer to a 1-D array of size n. It must be normalized to
the range [-1, 1].
@param n Number of entries in waveform
*/
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);

View File

@@ -16,7 +16,7 @@ class OnlineStream::Impl {
// Constructs the implementation object, passing `config` on to the
// feat_extractor_ member.
explicit Impl(const FeatureExtractorConfig &config)
: feat_extractor_(config) {}
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
}
@@ -67,7 +67,7 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
OnlineStream::~OnlineStream() = default;
void OnlineStream::AcceptWaveform(float sampling_rate, const float *waveform,
void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
int32_t n) {
impl_->AcceptWaveform(sampling_rate, waveform, n);
}

View File

@@ -20,12 +20,14 @@ class OnlineStream {
~OnlineStream();
/**
@param sampling_rate The sampling_rate of the input waveform. Should match
the one expected by the feature extractor.
@param waveform Pointer to a 1-D array of size n
@param sampling_rate The sampling rate of the input waveform. If it is not
equal to config.sampling_rate, the waveform is resampled
internally.
@param waveform Pointer to a 1-D array of size n. It must be normalized to
the range [-1, 1].
@param n Number of entries in waveform
*/
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
/**
* InputFinished() tells the class you won't be providing any

View File

@@ -76,6 +76,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
std::vector<int64_t> blanks(context_size, blank_id);
Hypotheses blank_hyp({{blanks, 0}});
r.hyps = std::move(blank_hyp);
r.tokens = std::move(blanks);
return r;
}