diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py
index 3d0dbef8..4b9028ba 100644
--- a/cmake/cmake_extension.py
+++ b/cmake/cmake_extension.py
@@ -64,6 +64,7 @@ def get_binaries():
         "sherpa-onnx-online-websocket-server",
         "sherpa-onnx-vad-microphone",
         "sherpa-onnx-vad-microphone-offline-asr",
+        "sherpa-onnx-vad-with-offline-asr",
     ]

     if enable_alsa():
diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index 2bd38404..1415c2ae 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -452,6 +452,10 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
     microphone.cc
   )

+  add_executable(sherpa-onnx-vad-with-offline-asr
+    sherpa-onnx-vad-with-offline-asr.cc
+  )
+
   add_executable(sherpa-onnx-vad-microphone-offline-asr
     sherpa-onnx-vad-microphone-offline-asr.cc
     microphone.cc
@@ -475,6 +479,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
     sherpa-onnx-microphone-offline-audio-tagging
     sherpa-onnx-vad-microphone
     sherpa-onnx-vad-microphone-offline-asr
+    sherpa-onnx-vad-with-offline-asr
   )
   if(SHERPA_ONNX_ENABLE_TTS)
     list(APPEND exes
diff --git a/sherpa-onnx/csrc/online-ebranchformer-transducer-model.cc b/sherpa-onnx/csrc/online-ebranchformer-transducer-model.cc
index 84d81e8e..79a7ffdd 100644
--- a/sherpa-onnx/csrc/online-ebranchformer-transducer-model.cc
+++ b/sherpa-onnx/csrc/online-ebranchformer-transducer-model.cc
@@ -85,9 +85,8 @@ OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel(
   }
 }

-
 void OnlineEbranchformerTransducerModel::InitEncoder(void *model_data,
-                                                    size_t model_data_length) {
+                                                     size_t model_data_length) {
   encoder_sess_ = std::make_unique<Ort::Session>(
       env_, model_data, model_data_length, encoder_sess_opts_);

@@ -153,9 +152,8 @@ void OnlineEbranchformerTransducerModel::InitEncoder(void *model_data,
   }
 }

-
 void OnlineEbranchformerTransducerModel::InitDecoder(void *model_data,
-                                                    size_t model_data_length) {
+                                                     size_t model_data_length) {
   decoder_sess_ = std::make_unique<Ort::Session>(
       env_, model_data, model_data_length, decoder_sess_opts_);

@@ -180,7 +178,7 @@ void OnlineEbranchformerTransducerModel::InitDecoder(void *model_data,
 }

 void OnlineEbranchformerTransducerModel::InitJoiner(void *model_data,
-                                                   size_t model_data_length) {
+                                                    size_t model_data_length) {
   joiner_sess_ = std::make_unique<Ort::Session>(
       env_, model_data, model_data_length, joiner_sess_opts_);

@@ -200,7 +198,6 @@ void OnlineEbranchformerTransducerModel::InitJoiner(void *model_data,
   }
 }

-
 std::vector<Ort::Value> OnlineEbranchformerTransducerModel::StackStates(
     const std::vector<std::vector<Ort::Value>> &states) const {
   int32_t batch_size = static_cast<int32_t>(states.size());
@@ -215,28 +212,28 @@ std::vector<Ort::Value> OnlineEbranchformerTransducerModel::StackStates(
   ans.reserve(num_states);

   for (int32_t i = 0; i != num_hidden_layers_; ++i) {
-    { // cached_key
+    {  // cached_key
       for (int32_t n = 0; n != batch_size; ++n) {
         buf[n] = &states[n][4 * i];
       }
       auto v = Cat(allocator, buf, /* axis */ 0);
       ans.push_back(std::move(v));
     }
-    { // cached_value
+    {  // cached_value
       for (int32_t n = 0; n != batch_size; ++n) {
         buf[n] = &states[n][4 * i + 1];
       }
       auto v = Cat(allocator, buf, 0);
       ans.push_back(std::move(v));
     }
-    { // cached_conv
+    {  // cached_conv
       for (int32_t n = 0; n != batch_size; ++n) {
         buf[n] = &states[n][4 * i + 2];
       }
       auto v = Cat(allocator, buf, 0);
       ans.push_back(std::move(v));
     }
-    { // cached_conv_fusion
+    {  // cached_conv_fusion
       for (int32_t n = 0; n != batch_size; ++n) {
         buf[n] = &states[n][4 * i + 3];
       }
@@ -245,7 +242,7 @@ std::vector<Ort::Value> OnlineEbranchformerTransducerModel::StackStates(
     }
   }

-  { // processed_lens
+  {  // processed_lens
     for (int32_t n = 0; n != batch_size; ++n) {
       buf[n] = &states[n][num_states - 1];
     }
@@ -256,11 +253,9 @@ std::vector<Ort::Value> OnlineEbranchformerTransducerModel::StackStates(
   return ans;
 }

-
 std::vector<std::vector<Ort::Value>>
 OnlineEbranchformerTransducerModel::UnStackStates(
     const std::vector<Ort::Value> &states) const {
-
   assert(static_cast<int32_t>(states.size()) == num_hidden_layers_ * 4 + 1);

   int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[0];
@@ -272,7 +267,7 @@ OnlineEbranchformerTransducerModel::UnStackStates(
   ans.resize(batch_size);

   for (int32_t i = 0; i != num_hidden_layers_; ++i) {
-    { // cached_key
+    {  // cached_key
       auto v = Unbind(allocator, &states[i * 4], /* axis */ 0);
       assert(static_cast<int32_t>(v.size()) == batch_size);

@@ -280,7 +275,7 @@ OnlineEbranchformerTransducerModel::UnStackStates(
         ans[n].push_back(std::move(v[n]));
       }
     }
-    { // cached_value
+    {  // cached_value
       auto v = Unbind(allocator, &states[i * 4 + 1], 0);
       assert(static_cast<int32_t>(v.size()) == batch_size);

@@ -288,7 +283,7 @@ OnlineEbranchformerTransducerModel::UnStackStates(
         ans[n].push_back(std::move(v[n]));
       }
     }
-    { // cached_conv
+    {  // cached_conv
       auto v = Unbind(allocator, &states[i * 4 + 2], 0);
       assert(static_cast<int32_t>(v.size()) == batch_size);

@@ -296,7 +291,7 @@ OnlineEbranchformerTransducerModel::UnStackStates(
         ans[n].push_back(std::move(v[n]));
       }
     }
-    { // cached_conv_fusion
+    {  // cached_conv_fusion
       auto v = Unbind(allocator, &states[i * 4 + 3], 0);
       assert(static_cast<int32_t>(v.size()) == batch_size);

@@ -306,7 +301,7 @@ OnlineEbranchformerTransducerModel::UnStackStates(
     }
   }

-  { // processed_lens
+  {  // processed_lens
     auto v = Unbind(allocator, &states.back(), 0);
     assert(static_cast<int32_t>(v.size()) == batch_size);

@@ -318,7 +313,6 @@ OnlineEbranchformerTransducerModel::UnStackStates(
   return ans;
 }

-
 std::vector<Ort::Value>
 OnlineEbranchformerTransducerModel::GetEncoderInitStates() {
   std::vector<Ort::Value> ans;
@@ -332,40 +326,37 @@ OnlineEbranchformerTransducerModel::GetEncoderInitStates() {
   int32_t channels_conv_fusion = 2 * hidden_size_;

   for (int32_t i = 0; i != num_hidden_layers_; ++i) {
-    { // cached_key_{i}
+    {  // cached_key_{i}
       std::array<int64_t, 4> s{1, num_heads_, left_context_len_, head_dim_};
-      auto v =
-          Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+      auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
       Fill(&v, 0);
       ans.push_back(std::move(v));
     }
-    { // cahced_value_{i}
+    {  // cached_value_{i}
       std::array<int64_t, 4> s{1, num_heads_, left_context_len_, head_dim_};
-      auto v =
-          Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+      auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
       Fill(&v, 0);
       ans.push_back(std::move(v));
     }
-    { // cached_conv_{i}
+    {  // cached_conv_{i}
       std::array<int64_t, 3> s{1, channels_conv, left_context_conv};
-      auto v =
-          Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+      auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
       Fill(&v, 0);
       ans.push_back(std::move(v));
     }
-    { // cached_conv_fusion_{i}
-      std::array<int64_t, 3> s{1, channels_conv_fusion, left_context_conv_fusion};
-      auto v =
-          Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
+    {  // cached_conv_fusion_{i}
+      std::array<int64_t, 3> s{1, channels_conv_fusion,
+                               left_context_conv_fusion};
+      auto v = Ort::Value::CreateTensor<float>(allocator_, s.data(), s.size());
       Fill(&v, 0);
       ans.push_back(std::move(v));
     }
   }  // num_hidden_layers_

-  { // processed_lens
+  {  // processed_lens
     std::array<int64_t, 1> s{1};
     auto v = Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size());
     Fill<int64_t>(&v, 0);
@@ -375,11 +366,10 @@ OnlineEbranchformerTransducerModel::GetEncoderInitStates() {
   return ans;
 }

-
 std::pair<Ort::Value, std::vector<Ort::Value>>
-OnlineEbranchformerTransducerModel::RunEncoder(Ort::Value features,
-                                               std::vector<Ort::Value> states,
-                                               Ort::Value /* processed_frames */) {
+OnlineEbranchformerTransducerModel::RunEncoder(
+    Ort::Value features, std::vector<Ort::Value> states,
+    Ort::Value /* processed_frames */) {
   std::vector<Ort::Value> encoder_inputs;
   encoder_inputs.reserve(1 + states.size());

@@ -402,7 +392,6 @@ OnlineEbranchformerTransducerModel::RunEncoder(Ort::Value features,
   return {std::move(encoder_out[0]), std::move(next_states)};
 }

-
 Ort::Value OnlineEbranchformerTransducerModel::RunDecoder(
     Ort::Value decoder_input) {
   auto decoder_out = decoder_sess_->Run(
@@ -411,9 +400,8 @@ Ort::Value OnlineEbranchformerTransducerModel::RunDecoder(
   return std::move(decoder_out[0]);
 }

-
-Ort::Value OnlineEbranchformerTransducerModel::RunJoiner(Ort::Value encoder_out,
-                                                         Ort::Value decoder_out) {
+Ort::Value OnlineEbranchformerTransducerModel::RunJoiner(
+    Ort::Value encoder_out, Ort::Value decoder_out) {
   std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
                                             std::move(decoder_out)};
   auto logit =
@@ -424,7 +412,6 @@ Ort::Value OnlineEbranchformerTransducerModel::RunJoiner(Ort::Value encoder_out,
   return std::move(logit[0]);
 }

-
 #if __ANDROID_API__ >= 9
 template OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel(
     AAssetManager *mgr, const OnlineModelConfig &config);
diff --git a/sherpa-onnx/csrc/online-ebranchformer-transducer-model.h b/sherpa-onnx/csrc/online-ebranchformer-transducer-model.h
index 4329c9f1..f1438cb1 100644
--- a/sherpa-onnx/csrc/online-ebranchformer-transducer-model.h
+++ b/sherpa-onnx/csrc/online-ebranchformer-transducer-model.h
@@ -22,7 +22,7 @@ class OnlineEbranchformerTransducerModel : public OnlineTransducerModel {

   template <typename Manager>
   OnlineEbranchformerTransducerModel(Manager *mgr,
-                                    const OnlineModelConfig &config);
+                                     const OnlineModelConfig &config);

   std::vector<Ort::Value> StackStates(
       const std::vector<std::vector<Ort::Value>> &states) const override;
diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline.cc b/sherpa-onnx/csrc/sherpa-onnx-offline.cc
index 022f7569..5509a861 100644
--- a/sherpa-onnx/csrc/sherpa-onnx-offline.cc
+++ b/sherpa-onnx/csrc/sherpa-onnx-offline.cc
@@ -131,10 +131,10 @@ for a list of pre-trained models to download.
   std::vector<sherpa_onnx::OfflineStream *> ss_pointers;
   float duration = 0;
   for (int32_t i = 1; i <= po.NumArgs(); ++i) {
-    const std::string wav_filename = po.GetArg(i);
+    std::string wav_filename = po.GetArg(i);
     int32_t sampling_rate = -1;
     bool is_ok = false;
-    const std::vector<float> samples =
+    std::vector<float> samples =
         sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
     if (!is_ok) {
       fprintf(stderr, "Failed to read '%s'\n", wav_filename.c_str());
diff --git a/sherpa-onnx/csrc/sherpa-onnx-vad-with-offline-asr.cc b/sherpa-onnx/csrc/sherpa-onnx-vad-with-offline-asr.cc
new file mode 100644
index 00000000..2c806c92
--- /dev/null
+++ b/sherpa-onnx/csrc/sherpa-onnx-vad-with-offline-asr.cc
@@ -0,0 +1,238 @@
+// sherpa-onnx/csrc/sherpa-onnx-vad-with-offline-asr.cc
+//
+// Copyright (c) 2025 Xiaomi Corporation
+
+#include <stdio.h>
+
+#include <algorithm>
+#include <chrono>  // NOLINT
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-recognizer.h"
+#include "sherpa-onnx/csrc/parse-options.h"
+#include "sherpa-onnx/csrc/resample.h"
+#include "sherpa-onnx/csrc/voice-activity-detector.h"
+#include "sherpa-onnx/csrc/wave-reader.h"
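+
+// This program reads a single wave file, uses silero VAD to cut it into
+// speech segments, and decodes each segment with a non-streaming ASR model.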
+
+int main(int32_t argc, char *argv[]) {
+  const char *kUsageMessage = R"usage(
+Speech recognition using VAD + non-streaming models with sherpa-onnx.
+
+Usage:
+
+Note you can download silero_vad.onnx using
+
+wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
+
+(0) FireRedAsr
+
+See https://k2-fsa.github.io/sherpa/onnx/FireRedAsr/pretrained.html
+
+  ./bin/sherpa-onnx-vad-with-offline-asr \
+    --tokens=./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/tokens.txt \
+    --fire-red-asr-encoder=./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/encoder.int8.onnx \
+    --fire-red-asr-decoder=./sherpa-onnx-fire-red-asr-large-zh_en-2025-02-16/decoder.int8.onnx \
+    --num-threads=1 \
+    --silero-vad-model=/path/to/silero_vad.onnx \
+    /path/to/foo.wav
+
+(1) Transducer from icefall
+
+See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-transducer/index.html
+
+  ./bin/sherpa-onnx-vad-with-offline-asr \
+    --silero-vad-model=/path/to/silero_vad.onnx \
+    --tokens=/path/to/tokens.txt \
+    --encoder=/path/to/encoder.onnx \
+    --decoder=/path/to/decoder.onnx \
+    --joiner=/path/to/joiner.onnx \
+    --num-threads=1 \
+    --decoding-method=greedy_search \
+    /path/to/foo.wav
+
+(2) Paraformer from FunASR
+
+See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-paraformer/index.html
+
+  ./bin/sherpa-onnx-vad-with-offline-asr \
+    --silero-vad-model=/path/to/silero_vad.onnx \
+    --tokens=/path/to/tokens.txt \
+    --paraformer=/path/to/model.onnx \
+    --num-threads=1 \
+    --decoding-method=greedy_search \
+    /path/to/foo.wav
+
+(3) Moonshine models
+
+See https://k2-fsa.github.io/sherpa/onnx/moonshine/index.html
+
+  ./bin/sherpa-onnx-vad-with-offline-asr \
+    --silero-vad-model=/path/to/silero_vad.onnx \
+    --moonshine-preprocessor=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/preprocess.onnx \
+    --moonshine-encoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/encode.int8.onnx \
+    --moonshine-uncached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/uncached_decode.int8.onnx \
+    --moonshine-cached-decoder=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/cached_decode.int8.onnx \
+    --tokens=/Users/fangjun/open-source/sherpa-onnx/scripts/moonshine/tokens.txt \
+    --num-threads=1 \
+    /path/to/foo.wav
+
+(4) Whisper models
+
+See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/whisper/tiny.en.html
+
+  ./bin/sherpa-onnx-vad-with-offline-asr \
+    --silero-vad-model=/path/to/silero_vad.onnx \
+    --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
+    --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
+    --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+    --num-threads=1 \
+    /path/to/foo.wav
+
+(5) NeMo CTC models
+
+See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/index.html
+
+  ./bin/sherpa-onnx-vad-with-offline-asr \
+    --silero-vad-model=/path/to/silero_vad.onnx \
+    --tokens=./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt \
+    --nemo-ctc-model=./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \
+    --num-threads=2 \
+    --decoding-method=greedy_search \
+    --debug=false \
+    ./sherpa-onnx-nemo-ctc-en-conformer-medium/test_wavs/0.wav
+
+(6) TDNN CTC model for the yesno recipe from icefall
+
+See https://k2-fsa.github.io/sherpa/onnx/pretrained_models/offline-ctc/yesno/index.html
+
+  ./bin/sherpa-onnx-vad-with-offline-asr \
+    --silero-vad-model=/path/to/silero_vad.onnx \
+    --sample-rate=8000 \
+    --feat-dim=23 \
+    --tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
+    --tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
+    ./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav
+
+The input wav should be a single-channel, 16-bit PCM encoded wave file; its
+sampling rate can be arbitrary and does not need to be 16kHz.
+
+Please refer to
+https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
+for a list of pre-trained models to download.
+)usage";
+
+  sherpa_onnx::ParseOptions po(kUsageMessage);
+  sherpa_onnx::OfflineRecognizerConfig asr_config;
+  asr_config.Register(&po);
+
+  sherpa_onnx::VadModelConfig vad_config;
+  vad_config.Register(&po);
+
+  po.Read(argc, argv);
+  if (po.NumArgs() != 1) {
+    fprintf(stderr, "Error: Please provide exactly 1 wave file. Given: %d\n\n",
+            po.NumArgs());
+    po.PrintUsage();
+    exit(EXIT_FAILURE);
+  }
+
+  fprintf(stderr, "%s\n", vad_config.ToString().c_str());
+  fprintf(stderr, "%s\n", asr_config.ToString().c_str());
+
+  if (!vad_config.Validate()) {
+    fprintf(stderr, "Errors in vad_config!\n");
+    return -1;
+  }
+
+  if (!asr_config.Validate()) {
+    fprintf(stderr, "Errors in ASR config!\n");
+    return -1;
+  }
+
+  fprintf(stderr, "Creating recognizer ...\n");
+  sherpa_onnx::OfflineRecognizer recognizer(asr_config);
+  fprintf(stderr, "Recognizer created!\n");
+
+  auto vad = std::make_unique<sherpa_onnx::VoiceActivityDetector>(vad_config);
+
+  fprintf(stderr, "Started\n");
+  const auto begin = std::chrono::steady_clock::now();
+
+  std::string wave_filename = po.GetArg(1);
+  fprintf(stderr, "Reading: %s\n", wave_filename.c_str());
+  int32_t sampling_rate = -1;
+  bool is_ok = false;
+  auto samples = sherpa_onnx::ReadWave(wave_filename, &sampling_rate, &is_ok);
+  if (!is_ok) {
+    fprintf(stderr, "Failed to read '%s'\n", wave_filename.c_str());
+    return -1;
+  }
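+
+  // The VAD model and everything after it operate on 16 kHz audio (note the
+  // 16000 constants below), so resample the input first if necessary.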
+  if (sampling_rate != 16000) {
+    fprintf(stderr, "Resampling from %d Hz to 16000 Hz\n", sampling_rate);
+    float min_freq = std::min<int32_t>(sampling_rate, 16000);
+    float lowpass_cutoff = 0.99 * 0.5 * min_freq;
+
+    int32_t lowpass_filter_width = 6;
+    auto resampler = std::make_unique<sherpa_onnx::LinearResample>(
+        sampling_rate, 16000, lowpass_cutoff, lowpass_filter_width);
+    std::vector<float> out_samples;
+    resampler->Resample(samples.data(), samples.size(), true, &out_samples);
+    samples = std::move(out_samples);
+    fprintf(stderr, "Resampling done\n");
+  }
+
+  fprintf(stderr, "Started!\n");
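+
+  // Feed the wave to the VAD in fixed-size windows. Whenever the VAD has
+  // buffered one or more complete speech segments, decode each segment with
+  // the offline recognizer and print its time range and text.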
+  int32_t window_size = vad_config.silero_vad.window_size;
+  int32_t i = 0;
+  while (i + window_size < static_cast<int32_t>(samples.size())) {
+    vad->AcceptWaveform(samples.data() + i, window_size);
+    i += window_size;
+    if (i + window_size >= static_cast<int32_t>(samples.size())) {
+      // Close to the end of the wave; flush the VAD so that it emits any
+      // speech segment it is still buffering.
+      vad->Flush();
+    }
+
+    while (!vad->Empty()) {
+      const auto &segment = vad->Front();
+      float duration = segment.samples.size() / 16000.;
+      float start_time = segment.start / 16000.;
+      float end_time = start_time + duration;
+      if (duration < 0.1) {
+        // Ignore segments shorter than 0.1 second; they are most likely
+        // false alarms from the VAD.
+        vad->Pop();
+        continue;
+      }
+
+      auto s = recognizer.CreateStream();
+      s->AcceptWaveform(16000, segment.samples.data(), segment.samples.size());
+      recognizer.DecodeStream(s.get());
+      const auto &result = s->GetResult();
+      if (!result.text.empty()) {
+        fprintf(stderr, "%.3f -- %.3f: %s\n", start_time, end_time,
+                result.text.c_str());
+      }
+      vad->Pop();
+    }
+  }
+
+  const auto end = std::chrono::steady_clock::now();
+
+  float elapsed_seconds =
+      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
+          .count() /
+      1000.;
+
+  fprintf(stderr, "num threads: %d\n", asr_config.model_config.num_threads);
+  fprintf(stderr, "decoding method: %s\n", asr_config.decoding_method.c_str());
+  if (asr_config.decoding_method == "modified_beam_search") {
+    fprintf(stderr, "max active paths: %d\n", asr_config.max_active_paths);
+  }
+
+  float duration = samples.size() / 16000.;
+  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
+  float rtf = elapsed_seconds / duration;
+  fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
+          elapsed_seconds, duration, rtf);
+
+  return 0;
+}