diff --git a/python-api-examples/decode-file.py b/python-api-examples/decode-file.py index 368499f8..bad90b6f 100755 --- a/python-api-examples/decode-file.py +++ b/python-api-examples/decode-file.py @@ -78,8 +78,6 @@ def get_args(): def main(): - sample_rate = 16000 - args = get_args() assert_file_exists(args.encoder) assert_file_exists(args.decoder) @@ -95,12 +93,16 @@ def main(): decoder=args.decoder, joiner=args.joiner, num_threads=args.num_threads, - sample_rate=sample_rate, + sample_rate=16000, feature_dim=80, decoding_method=args.decoding_method, ) with wave.open(args.wave_filename) as f: - assert f.getframerate() == sample_rate, f.getframerate() + # If the wave file has a different sampling rate from the one + # expected by the model (16 kHz in our case), we will do + # resampling inside sherpa-onnx + wave_file_sample_rate = f.getframerate() + assert f.getnchannels() == 1, f.getnchannels() assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes num_samples = f.getnframes() @@ -110,17 +112,17 @@ def main(): samples_float32 = samples_float32 / 32768 - duration = len(samples_float32) / sample_rate + duration = len(samples_float32) / wave_file_sample_rate start_time = time.time() print("Started!") stream = recognizer.create_stream() - stream.accept_waveform(sample_rate, samples_float32) + stream.accept_waveform(wave_file_sample_rate, samples_float32) - tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) - stream.accept_waveform(sample_rate, tail_paddings) + tail_paddings = np.zeros(int(0.2 * wave_file_sample_rate), dtype=np.float32) + stream.accept_waveform(wave_file_sample_rate, tail_paddings) stream.input_finished() diff --git a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py index 44d22549..9840b51e 100755 --- a/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py +++ b/python-api-examples/speech-recognition-from-microphone-with-endpoint-detection.py @@ -100,7 +100,9 @@ def main(): recognizer = create_recognizer() print("Started! Please speak") - sample_rate = 16000 + # The model is using 16 kHz, we use 48 kHz here to demonstrate that + # sherpa-onnx will do resampling inside. + sample_rate = 48000 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms last_result = "" stream = recognizer.create_stream() diff --git a/python-api-examples/speech-recognition-from-microphone.py b/python-api-examples/speech-recognition-from-microphone.py index bca4f2b0..e13b8d7f 100755 --- a/python-api-examples/speech-recognition-from-microphone.py +++ b/python-api-examples/speech-recognition-from-microphone.py @@ -92,9 +92,12 @@ def create_recognizer(): def main(): - print("Started! Please speak") recognizer = create_recognizer() - sample_rate = 16000 + print("Started! Please speak") + + # The model is using 16 kHz, we use 48 kHz here to demonstrate that + # sherpa-onnx will do resampling inside. + sample_rate = 48000 samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms last_result = "" stream = recognizer.create_stream() diff --git a/sherpa-onnx/c-api/c-api.h b/sherpa-onnx/c-api/c-api.h index 6fa2abe5..e979d9aa 100644 --- a/sherpa-onnx/c-api/c-api.h +++ b/sherpa-onnx/c-api/c-api.h @@ -115,8 +115,9 @@ void DestoryOnlineStream(SherpaOnnxOnlineStream *stream); /// decoding. /// /// @param stream A pointer returned by CreateOnlineStream(). -/// @param sample_rate Sampler rate of the input samples. It has to be 16 kHz -/// for models from icefall. +/// @param sample_rate Sample rate of the input samples. If it is different +/// from config.feat_config.sample_rate, we will do +/// resampling inside sherpa-onnx. /// @param samples A pointer to a 1-D array containing audio samples. /// The range of samples has to be normalized to [-1, 1]. /// @param n Number of elements in the samples array. diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index 31b789ab..eab137e7 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -11,6 +11,8 @@ #include #include "kaldi-native-fbank/csrc/online-feature.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/resample.h" namespace sherpa_onnx { @@ -50,6 +52,46 @@ class FeatureExtractor::Impl { void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { std::lock_guard lock(mutex_); + + if (resampler_) { + if (sampling_rate != resampler_->GetInputSamplingRate()) { + SHERPA_ONNX_LOGE( + "You changed the input sampling rate!! Expected: %d, given: " + "%d", + resampler_->GetInputSamplingRate(), sampling_rate); + exit(-1); + } + + std::vector samples; + resampler_->Resample(waveform, n, false, &samples); + fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(), + samples.size()); + return; + } + + if (sampling_rate != opts_.frame_opts.samp_freq) { + SHERPA_ONNX_LOGE( + "Creating a resampler:\n" + " in_sample_rate: %d\n" + " output_sample_rate: %d\n", + sampling_rate, static_cast(opts_.frame_opts.samp_freq)); + + float min_freq = + std::min(sampling_rate, opts_.frame_opts.samp_freq); + float lowpass_cutoff = 0.99 * 0.5 * min_freq; + + int32_t lowpass_filter_width = 6; + resampler_ = std::make_unique( + sampling_rate, opts_.frame_opts.samp_freq, lowpass_cutoff, + lowpass_filter_width); + + std::vector samples; + resampler_->Resample(waveform, n, false, &samples); + fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(), + samples.size()); + return; + } + fbank_->AcceptWaveform(sampling_rate, waveform, n); } @@ -100,6 +142,7 @@ class FeatureExtractor::Impl { std::unique_ptr fbank_; knf::FbankOptions opts_; mutable std::mutex mutex_; + std::unique_ptr resampler_; }; FeatureExtractor::FeatureExtractor(const FeatureExtractorConfig &config /*={}*/) diff --git a/sherpa-onnx/csrc/features.h b/sherpa-onnx/csrc/features.h index cfdb7fa6..831f221e 100644 --- a/sherpa-onnx/csrc/features.h +++ b/sherpa-onnx/csrc/features.h @@ -29,9 +29,11 @@ class FeatureExtractor { ~FeatureExtractor(); /** - @param sampling_rate The sampling_rate of the input waveform. Should match - the one expected by the feature extractor. - @param waveform Pointer to a 1-D array of size n + @param sampling_rate The sampling_rate of the input waveform. If it does + not equal to config.sampling_rate, we will do + resampling inside. + @param waveform Pointer to a 1-D array of size n. It must be normalized to + the range [-1, 1]. @param n Number of entries in waveform */ void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n); diff --git a/sherpa-onnx/csrc/online-stream.cc b/sherpa-onnx/csrc/online-stream.cc index 1ed1588f..da397d67 100644 --- a/sherpa-onnx/csrc/online-stream.cc +++ b/sherpa-onnx/csrc/online-stream.cc @@ -16,7 +16,7 @@ class OnlineStream::Impl { explicit Impl(const FeatureExtractorConfig &config) : feat_extractor_(config) {} - void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n) { + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { feat_extractor_.AcceptWaveform(sampling_rate, waveform, n); } @@ -67,7 +67,7 @@ OnlineStream::OnlineStream(const FeatureExtractorConfig &config /*= {}*/) OnlineStream::~OnlineStream() = default; -void OnlineStream::AcceptWaveform(float sampling_rate, const float *waveform, +void OnlineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { impl_->AcceptWaveform(sampling_rate, waveform, n); } diff --git a/sherpa-onnx/csrc/online-stream.h b/sherpa-onnx/csrc/online-stream.h index 0bba1847..32fe1248 100644 --- a/sherpa-onnx/csrc/online-stream.h +++ b/sherpa-onnx/csrc/online-stream.h @@ -20,12 +20,14 @@ class OnlineStream { ~OnlineStream(); /** - @param sampling_rate The sampling_rate of the input waveform. Should match - the one expected by the feature extractor. - @param waveform Pointer to a 1-D array of size n + @param sampling_rate The sampling_rate of the input waveform. If it does + not equal to config.sampling_rate, we will do + resampling inside. + @param waveform Pointer to a 1-D array of size n. It must be normalized to + the range [-1, 1]. @param n Number of entries in waveform */ - void AcceptWaveform(float sampling_rate, const float *waveform, int32_t n); + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n); /** * InputFinished() tells the class you won't be providing any diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index 2d9825d9..c7c00b5b 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -76,6 +76,7 @@ OnlineTransducerModifiedBeamSearchDecoder::GetEmptyResult() const { std::vector blanks(context_size, blank_id); Hypotheses blank_hyp({{blanks, 0}}); r.hyps = std::move(blank_hyp); + r.tokens = std::move(blanks); return r; } diff --git a/sherpa-onnx/python/csrc/online-stream.cc b/sherpa-onnx/python/csrc/online-stream.cc index 06a46a59..411354a0 100644 --- a/sherpa-onnx/python/csrc/online-stream.cc +++ b/sherpa-onnx/python/csrc/online-stream.cc @@ -8,13 +8,27 @@ namespace sherpa_onnx { +constexpr const char *kAcceptWaveformUsage = R"( +Process audio samples. + +Args: + sample_rate: + Sample rate of the input samples. If it is different from the one + expected by the model, we will do resampling inside. + waveform: + A 1-D float32 tensor containing audio samples. It must be normalized + to the range [-1, 1]. +)"; + void PybindOnlineStream(py::module *m) { using PyClass = OnlineStream; py::class_(*m, "OnlineStream") - .def("accept_waveform", - [](PyClass &self, float sample_rate, py::array_t waveform) { - self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); - }) + .def( + "accept_waveform", + [](PyClass &self, float sample_rate, py::array_t waveform) { + self.AcceptWaveform(sample_rate, waveform.data(), waveform.size()); + }, + py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage) .def("input_finished", &PyClass::InputFinished); }