Support resampling (#77)
This commit is contained in:
@@ -78,8 +78,6 @@ def get_args():
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
sample_rate = 16000
|
|
||||||
|
|
||||||
args = get_args()
|
args = get_args()
|
||||||
assert_file_exists(args.encoder)
|
assert_file_exists(args.encoder)
|
||||||
assert_file_exists(args.decoder)
|
assert_file_exists(args.decoder)
|
||||||
@@ -95,12 +93,16 @@ def main():
|
|||||||
decoder=args.decoder,
|
decoder=args.decoder,
|
||||||
joiner=args.joiner,
|
joiner=args.joiner,
|
||||||
num_threads=args.num_threads,
|
num_threads=args.num_threads,
|
||||||
sample_rate=sample_rate,
|
sample_rate=16000,
|
||||||
feature_dim=80,
|
feature_dim=80,
|
||||||
decoding_method=args.decoding_method,
|
decoding_method=args.decoding_method,
|
||||||
)
|
)
|
||||||
with wave.open(args.wave_filename) as f:
|
with wave.open(args.wave_filename) as f:
|
||||||
assert f.getframerate() == sample_rate, f.getframerate()
|
# If the wave file has a different sampling rate from the one
|
||||||
|
# expected by the model (16 kHz in our case), we will do
|
||||||
|
# resampling inside sherpa-onnx
|
||||||
|
wave_file_sample_rate = f.getframerate()
|
||||||
|
|
||||||
assert f.getnchannels() == 1, f.getnchannels()
|
assert f.getnchannels() == 1, f.getnchannels()
|
||||||
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
|
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
|
||||||
num_samples = f.getnframes()
|
num_samples = f.getnframes()
|
||||||
@@ -110,17 +112,17 @@ def main():
|
|||||||
|
|
||||||
samples_float32 = samples_float32 / 32768
|
samples_float32 = samples_float32 / 32768
|
||||||
|
|
||||||
duration = len(samples_float32) / sample_rate
|
duration = len(samples_float32) / wave_file_sample_rate
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
print("Started!")
|
print("Started!")
|
||||||
|
|
||||||
stream = recognizer.create_stream()
|
stream = recognizer.create_stream()
|
||||||
|
|
||||||
stream.accept_waveform(sample_rate, samples_float32)
|
stream.accept_waveform(wave_file_sample_rate, samples_float32)
|
||||||
|
|
||||||
tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32)
|
tail_paddings = np.zeros(int(0.2 * wave_file_sample_rate), dtype=np.float32)
|
||||||
stream.accept_waveform(sample_rate, tail_paddings)
|
stream.accept_waveform(wave_file_sample_rate, tail_paddings)
|
||||||
|
|
||||||
stream.input_finished()
|
stream.input_finished()
|
||||||
|
|
||||||
|
|||||||
@@ -100,7 +100,9 @@ def main():
|
|||||||
recognizer = create_recognizer()
|
recognizer = create_recognizer()
|
||||||
print("Started! Please speak")
|
print("Started! Please speak")
|
||||||
|
|
||||||
sample_rate = 16000
|
# The model uses 16 kHz; we use 48 kHz here to demonstrate that
|
||||||
|
# sherpa-onnx will do resampling inside.
|
||||||
|
sample_rate = 48000
|
||||||
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
||||||
last_result = ""
|
last_result = ""
|
||||||
stream = recognizer.create_stream()
|
stream = recognizer.create_stream()
|
||||||
|
|||||||
@@ -92,9 +92,12 @@ def create_recognizer():
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
print("Started! Please speak")
|
|
||||||
recognizer = create_recognizer()
|
recognizer = create_recognizer()
|
||||||
sample_rate = 16000
|
print("Started! Please speak")
|
||||||
|
|
||||||
|
# The model uses 16 kHz; we use 48 kHz here to demonstrate that
|
||||||
|
# sherpa-onnx will do resampling inside.
|
||||||
|
sample_rate = 48000
|
||||||
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms
|
||||||
last_result = ""
|
last_result = ""
|
||||||
stream = recognizer.create_stream()
|
stream = recognizer.create_stream()
|
||||||
|
|||||||
@@ -115,8 +115,9 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream);
|
|||||||
/// decoding.
|
/// decoding.
|
||||||
///
|
///
|
||||||
/// @param stream A pointer returned by CreateOnlineStream().
|
/// @param stream A pointer returned by CreateOnlineStream().
|
||||||
/// @param sample_rate Sampler rate of the input samples. It has to be 16 kHz
|
/// @param sample_rate Sample rate of the input samples. If it is different
|
||||||
/// for models from icefall.
|
/// from config.feat_config.sample_rate, we will do
|
||||||
|
/// resampling inside sherpa-onnx.
|
||||||
/// @param samples A pointer to a 1-D array containing audio samples.
|
/// @param samples A pointer to a 1-D array containing audio samples.
|
||||||
/// The range of samples has to be normalized to [-1, 1].
|
/// The range of samples has to be normalized to [-1, 1].
|
||||||
/// @param n Number of elements in the samples array.
|
/// @param n Number of elements in the samples array.
|
||||||
|
|||||||
@@ -11,6 +11,8 @@
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "kaldi-native-fbank/csrc/online-feature.h"
|
#include "kaldi-native-fbank/csrc/online-feature.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/resample.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -50,6 +52,46 @@ class FeatureExtractor::Impl {
|
|||||||
|
|
||||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||||
std::lock_guard<std::mutex> lock(mutex_);
|
std::lock_guard<std::mutex> lock(mutex_);
|
||||||
|
|
||||||
|
if (resampler_) {
|
||||||
|
if (sampling_rate != resampler_->GetInputSamplingRate()) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"You changed the input sampling rate!! Expected: %d, given: "
|
||||||
|
"%d",
|
||||||
|
resampler_->GetInputSamplingRate(), sampling_rate);
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> samples;
|
||||||
|
resampler_->Resample(waveform, n, false, &samples);
|
||||||
|
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
|
||||||
|
samples.size());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sampling_rate != opts_.frame_opts.samp_freq) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Creating a resampler:\n"
|
||||||
|
" in_sample_rate: %d\n"
|
||||||
|
" output_sample_rate: %d\n",
|
||||||
|
sampling_rate, static_cast<int32_t>(opts_.frame_opts.samp_freq));
|
||||||
|
|
||||||
|
float min_freq =
|
||||||
|
std::min<int32_t>(sampling_rate, opts_.frame_opts.samp_freq);
|
||||||
|
float lowpass_cutoff = 0.99 * 0.5 * min_freq;
|
||||||
|
|
||||||
|
int32_t lowpass_filter_width = 6;
|
||||||
|
resampler_ = std::make_unique<LinearResample>(
|
||||||
|
sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff,
|
||||||
|
lowpass_filter_width);
|
||||||
|
|
||||||
|
std::vector<float> samples;
|
||||||
|
resampler_->Resample(waveform, n, false, &samples);
|
||||||
|
fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(),
|
||||||
|
samples.size());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
fbank_->AcceptWaveform(sampling_rate, waveform, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,6 +142,7 @@ class FeatureExtractor::Impl {
|
|||||||
std::unique_ptr<knf::OnlineFbank> fbank_;
|
std::unique_ptr<knf::OnlineFbank> fbank_;
|
||||||
knf::FbankOptions opts_;
|
knf::FbankOptions opts_;
|
||||||
mutable std::mutex mutex_;
|
mutable std::mutex mutex_;
|
||||||
|
std::unique_ptr<LinearResample> resampler_;
|
||||||
};
|
};
|
||||||
|
|
||||||
FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
|
FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/)
|
||||||
|
|||||||
@@ -29,9 +29,11 @@ class FeatureExtractor {
|
|||||||
~FeatureExtractor();
|
~FeatureExtractor();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@param sampling_rate The sampling_rate of the input waveform. Should match
|
@param sampling_rate The sampling_rate of the input waveform. If it is
|
||||||
the one expected by the feature extractor.
|
not equal to config.sampling_rate, we will do
|
||||||
@param waveform Pointer to a 1-D array of size n
|
resampling inside.
|
||||||
|
@param waveform Pointer to a 1-D array of size n. It must be normalized to
|
||||||
|
the range [-1, 1].
|
||||||
@param n Number of entries in waveform
|
@param n Number of entries in waveform
|
||||||
*/
|
*/
|
||||||
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
|
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class OnlineStream::Impl {
|
|||||||
explicit Impl(const FeatureExtractorConfig &config)
|
explicit Impl(const FeatureExtractorConfig &config)
|
||||||
: feat_extractor_(config) {}
|
: feat_extractor_(config) {}
|
||||||
|
|
||||||
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) {
|
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) {
|
||||||
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
|
feat_extractor_.AcceptWaveform(sampling_rate, waveform, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -67,7 +67,7 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/)
|
|||||||
|
|
||||||
OnlineStream::~OnlineStream() = default;
|
OnlineStream::~OnlineStream() = default;
|
||||||
|
|
||||||
void OnlineStream::AcceptWaveform(float sampling_rate, const float *waveform,
|
void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform,
|
||||||
int32_t n) {
|
int32_t n) {
|
||||||
impl_->AcceptWaveform(sampling_rate, waveform, n);
|
impl_->AcceptWaveform(sampling_rate, waveform, n);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,12 +20,14 @@ class OnlineStream {
|
|||||||
~OnlineStream();
|
~OnlineStream();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@param sampling_rate The sampling_rate of the input waveform. Should match
|
@param sampling_rate The sampling_rate of the input waveform. If it is
|
||||||
the one expected by the feature extractor.
|
not equal to config.sampling_rate, we will do
|
||||||
@param waveform Pointer to a 1-D array of size n
|
resampling inside.
|
||||||
|
@param waveform Pointer to a 1-D array of size n. It must be normalized to
|
||||||
|
the range [-1, 1].
|
||||||
@param n Number of entries in waveform
|
@param n Number of entries in waveform
|
||||||
*/
|
*/
|
||||||
void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n);
|
void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* InputFinished() tells the class you won't be providing any
|
* InputFinished() tells the class you won't be providing any
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const {
|
|||||||
std::vector<int64_t> blanks(context_size, blank_id);
|
std::vector<int64_t> blanks(context_size, blank_id);
|
||||||
Hypotheses blank_hyp({{blanks, 0}});
|
Hypotheses blank_hyp({{blanks, 0}});
|
||||||
r.hyps = std::move(blank_hyp);
|
r.hyps = std::move(blank_hyp);
|
||||||
|
r.tokens = std::move(blanks);
|
||||||
return r;
|
return r;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,13 +8,27 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
constexpr const char *kAcceptWaveformUsage = R"(
|
||||||
|
Process audio samples.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample_rate:
|
||||||
|
Sample rate of the input samples. If it is different from the one
|
||||||
|
expected by the model, we will do resampling inside.
|
||||||
|
waveform:
|
||||||
|
A 1-D float32 tensor containing audio samples. It must be normalized
|
||||||
|
to the range [-1, 1].
|
||||||
|
)";
|
||||||
|
|
||||||
void PybindOnlineStream(py::module *m) {
|
void PybindOnlineStream(py::module *m) {
|
||||||
using PyClass = OnlineStream;
|
using PyClass = OnlineStream;
|
||||||
py::class_<PyClass>(*m, "OnlineStream")
|
py::class_<PyClass>(*m, "OnlineStream")
|
||||||
.def("accept_waveform",
|
.def(
|
||||||
[](PyClass &self, float sample_rate, py::array_t<float> waveform) {
|
"accept_waveform",
|
||||||
self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
|
[](PyClass &self, float sample_rate, py::array_t<float> waveform) {
|
||||||
})
|
self.AcceptWaveform(sample_rate, waveform.data(), waveform.size());
|
||||||
|
},
|
||||||
|
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage)
|
||||||
.def("input_finished", &PyClass::InputFinished);
|
.def("input_finished", &PyClass::InputFinished);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user