From 028b8f2718cfb9ae1a6cd3d158857b3936defd36 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 11 May 2025 00:23:32 +0800 Subject: [PATCH] Add C++ example for streaming ASR with SenseVoice. (#2199) --- cxx-api-examples/CMakeLists.txt | 11 + cxx-api-examples/dolphin-ctc-cxx-api.cc | 10 +- cxx-api-examples/fire-red-asr-cxx-api.cc | 10 +- cxx-api-examples/moonshine-cxx-api.cc | 10 +- cxx-api-examples/sense-voice-cxx-api.cc | 10 +- ...e-simulate-streaming-microphone-cxx-api.cc | 282 ++++++++++++++++++ .../sense-voice-with-hr-cxx-api.cc | 10 +- cxx-api-examples/sherpa-display.h | 66 ++++ .../streaming-zipformer-cxx-api.cc | 12 +- .../streaming-zipformer-rtf-cxx-api.cc | 12 +- .../streaming-zipformer-with-hr-cxx-api.cc | 12 +- cxx-api-examples/whisper-cxx-api.cc | 10 +- ...mulate-streaming-sense-voice-microphone.py | 21 +- sherpa-onnx/c-api/cxx-api.cc | 38 +++ sherpa-onnx/c-api/cxx-api.h | 23 ++ sherpa-onnx/csrc/homophone-replacer.cc | 37 ++- 16 files changed, 514 insertions(+), 60 deletions(-) create mode 100644 cxx-api-examples/sense-voice-simulate-streaming-microphone-cxx-api.cc create mode 100644 cxx-api-examples/sherpa-display.h diff --git a/cxx-api-examples/CMakeLists.txt b/cxx-api-examples/CMakeLists.txt index e0dc5e66..4ae59609 100644 --- a/cxx-api-examples/CMakeLists.txt +++ b/cxx-api-examples/CMakeLists.txt @@ -27,6 +27,17 @@ target_link_libraries(moonshine-cxx-api sherpa-onnx-cxx-api) add_executable(sense-voice-cxx-api ./sense-voice-cxx-api.cc) target_link_libraries(sense-voice-cxx-api sherpa-onnx-cxx-api) +if(SHERPA_ONNX_ENABLE_PORTAUDIO) + add_executable(sense-voice-simulate-streaming-microphone-cxx-api + ./sense-voice-simulate-streaming-microphone-cxx-api.cc + ${CMAKE_CURRENT_LIST_DIR}/../sherpa-onnx/csrc/microphone.cc + ) + target_link_libraries(sense-voice-simulate-streaming-microphone-cxx-api + sherpa-onnx-cxx-api + portaudio_static + ) +endif() + add_executable(sense-voice-with-hr-cxx-api ./sense-voice-with-hr-cxx-api.cc) target_link_libraries(sense-voice-with-hr-cxx-api sherpa-onnx-cxx-api) diff --git a/cxx-api-examples/dolphin-ctc-cxx-api.cc b/cxx-api-examples/dolphin-ctc-cxx-api.cc index c219b4a9..62258b71 100644 --- a/cxx-api-examples/dolphin-ctc-cxx-api.cc +++ b/cxx-api-examples/dolphin-ctc-cxx-api.cc @@ -33,8 +33,8 @@ int32_t main() { config.model_config.num_threads = 1; std::cout << "Loading model\n"; - OfflineRecognizer recongizer = OfflineRecognizer::Create(config); - if (!recongizer.Get()) { + OfflineRecognizer recognizer = OfflineRecognizer::Create(config); + if (!recognizer.Get()) { std::cerr << "Please check your config\n"; return -1; } @@ -49,13 +49,13 @@ int32_t main() { std::cout << "Start recognition\n"; const auto begin = std::chrono::steady_clock::now(); - OfflineStream stream = recongizer.CreateStream(); + OfflineStream stream = recognizer.CreateStream(); stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), wave.samples.size()); - recongizer.Decode(&stream); + recognizer.Decode(&stream); - OfflineRecognizerResult result = recongizer.GetResult(&stream); + OfflineRecognizerResult result = recognizer.GetResult(&stream); const auto end = std::chrono::steady_clock::now(); const float elapsed_seconds = diff --git a/cxx-api-examples/fire-red-asr-cxx-api.cc b/cxx-api-examples/fire-red-asr-cxx-api.cc index cd363cee..9d0c8fe1 100644 --- a/cxx-api-examples/fire-red-asr-cxx-api.cc +++ b/cxx-api-examples/fire-red-asr-cxx-api.cc @@ -32,8 +32,8 @@ int32_t main() { config.model_config.num_threads = 1; std::cout << "Loading model\n"; - OfflineRecognizer 
recongizer = OfflineRecognizer::Create(config); - if (!recongizer.Get()) { + OfflineRecognizer recognizer = OfflineRecognizer::Create(config); + if (!recognizer.Get()) { std::cerr << "Please check your config\n"; return -1; } @@ -50,13 +50,13 @@ int32_t main() { std::cout << "Start recognition\n"; const auto begin = std::chrono::steady_clock::now(); - OfflineStream stream = recongizer.CreateStream(); + OfflineStream stream = recognizer.CreateStream(); stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), wave.samples.size()); - recongizer.Decode(&stream); + recognizer.Decode(&stream); - OfflineRecognizerResult result = recongizer.GetResult(&stream); + OfflineRecognizerResult result = recognizer.GetResult(&stream); const auto end = std::chrono::steady_clock::now(); const float elapsed_seconds = diff --git a/cxx-api-examples/moonshine-cxx-api.cc b/cxx-api-examples/moonshine-cxx-api.cc index c2ce565c..e198a097 100644 --- a/cxx-api-examples/moonshine-cxx-api.cc +++ b/cxx-api-examples/moonshine-cxx-api.cc @@ -36,8 +36,8 @@ int32_t main() { config.model_config.num_threads = 1; std::cout << "Loading model\n"; - OfflineRecognizer recongizer = OfflineRecognizer::Create(config); - if (!recongizer.Get()) { + OfflineRecognizer recognizer = OfflineRecognizer::Create(config); + if (!recognizer.Get()) { std::cerr << "Please check your config\n"; return -1; } @@ -54,13 +54,13 @@ int32_t main() { std::cout << "Start recognition\n"; const auto begin = std::chrono::steady_clock::now(); - OfflineStream stream = recongizer.CreateStream(); + OfflineStream stream = recognizer.CreateStream(); stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), wave.samples.size()); - recongizer.Decode(&stream); + recognizer.Decode(&stream); - OfflineRecognizerResult result = recongizer.GetResult(&stream); + OfflineRecognizerResult result = recognizer.GetResult(&stream); const auto end = std::chrono::steady_clock::now(); const float elapsed_seconds = diff --git a/cxx-api-examples/sense-voice-cxx-api.cc b/cxx-api-examples/sense-voice-cxx-api.cc index ea642b98..7a8a9df3 100644 --- a/cxx-api-examples/sense-voice-cxx-api.cc +++ b/cxx-api-examples/sense-voice-cxx-api.cc @@ -32,8 +32,8 @@ int32_t main() { config.model_config.num_threads = 1; std::cout << "Loading model\n"; - OfflineRecognizer recongizer = OfflineRecognizer::Create(config); - if (!recongizer.Get()) { + OfflineRecognizer recognizer = OfflineRecognizer::Create(config); + if (!recognizer.Get()) { std::cerr << "Please check your config\n"; return -1; } @@ -51,13 +51,13 @@ int32_t main() { std::cout << "Start recognition\n"; const auto begin = std::chrono::steady_clock::now(); - OfflineStream stream = recongizer.CreateStream(); + OfflineStream stream = recognizer.CreateStream(); stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), wave.samples.size()); - recongizer.Decode(&stream); + recognizer.Decode(&stream); - OfflineRecognizerResult result = recongizer.GetResult(&stream); + OfflineRecognizerResult result = recognizer.GetResult(&stream); const auto end = std::chrono::steady_clock::now(); const float elapsed_seconds = diff --git a/cxx-api-examples/sense-voice-simulate-streaming-microphone-cxx-api.cc b/cxx-api-examples/sense-voice-simulate-streaming-microphone-cxx-api.cc new file mode 100644 index 00000000..0ace2336 --- /dev/null +++ b/cxx-api-examples/sense-voice-simulate-streaming-microphone-cxx-api.cc @@ -0,0 +1,282 @@ +// cxx-api-examples/sense-voice-simulate-streaming-microphone-cxx-api.cc +// Copyright (c) 2025 Xiaomi Corporation + +// +// This 
file demonstrates how to use SenseVoice with sherpa-onnx's C++ API
+// for streaming speech recognition from a microphone.
+//
+// clang-format off
+//
+// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/silero_vad.onnx
+//
+// wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+// tar xvf sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+// rm sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17.tar.bz2
+//
+// clang-format on
+
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <chrono>              // NOLINT
+#include <condition_variable>  // NOLINT
+#include <iostream>
+#include <mutex>  // NOLINT
+#include <queue>
+#include <vector>
+
+#include "portaudio.h"       // NOLINT
+#include "sherpa-display.h"  // NOLINT
+#include "sherpa-onnx/c-api/cxx-api.h"
+#include "sherpa-onnx/csrc/microphone.h"
+
+std::queue<std::vector<float>> samples_queue;
+std::condition_variable condition_variable;
+std::mutex mutex;
+bool stop = false;
+
+static void Handler(int32_t /*sig*/) {
+  stop = true;
+  condition_variable.notify_one();
+  fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n");
+}
+
+static int32_t RecordCallback(const void *input_buffer,
+                              void * /*output_buffer*/,
+                              unsigned long frames_per_buffer,  // NOLINT
+                              const PaStreamCallbackTimeInfo * /*time_info*/,
+                              PaStreamCallbackFlags /*status_flags*/,
+                              void * /*user_data*/) {
+  std::lock_guard<std::mutex> lock(mutex);
+  samples_queue.emplace(
+      reinterpret_cast<const float *>(input_buffer),
+      reinterpret_cast<const float *>(input_buffer) + frames_per_buffer);
+  condition_variable.notify_one();
+
+  return stop ? paComplete : paContinue;
+}
+
+static sherpa_onnx::cxx::VoiceActivityDetector CreateVad() {
+  using namespace sherpa_onnx::cxx;  // NOLINT
+  VadModelConfig config;
+  config.silero_vad.model = "./silero_vad.onnx";
+  config.silero_vad.threshold = 0.5;
+  config.silero_vad.min_silence_duration = 0.1;
+  config.silero_vad.min_speech_duration = 0.25;
+  config.silero_vad.max_speech_duration = 8;
+  config.sample_rate = 16000;
+  config.debug = false;
+
+  VoiceActivityDetector vad = VoiceActivityDetector::Create(config, 20);
+  if (!vad.Get()) {
+    std::cerr << "Failed to create VAD. Please check your config\n";
+    exit(-1);
+  }
+
+  return vad;
+}
+
+static sherpa_onnx::cxx::OfflineRecognizer CreateOfflineRecognizer() {
+  using namespace sherpa_onnx::cxx;  // NOLINT
+  OfflineRecognizerConfig config;
+
+  config.model_config.sense_voice.model =
+      "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/model.int8.onnx";
+  config.model_config.sense_voice.use_itn = false;
+  config.model_config.sense_voice.language = "auto";
+  config.model_config.tokens =
+      "./sherpa-onnx-sense-voice-zh-en-ja-ko-yue-2024-07-17/tokens.txt";
+
+  config.model_config.num_threads = 2;
+  config.model_config.debug = false;
+
+  std::cout << "Loading model\n";
+  OfflineRecognizer recognizer = OfflineRecognizer::Create(config);
+  if (!recognizer.Get()) {
+    std::cerr << "Please check your config\n";
+    exit(-1);
+  }
+  std::cout << "Loading model done\n";
+  return recognizer;
+}
+
+int32_t main() {
+  signal(SIGINT, Handler);
+
+  using namespace sherpa_onnx::cxx;  // NOLINT
+
+  auto vad = CreateVad();
+  auto recognizer = CreateOfflineRecognizer();
+
+  sherpa_onnx::Microphone mic;
+
+  PaDeviceIndex num_devices = Pa_GetDeviceCount();
+  std::cout << "Num devices: " << num_devices << "\n";
+  if (num_devices == 0) {
+    std::cerr << " If you are using Linux, please try "
+                 "./build/bin/sense-voice-simulate-streaming-alsa-cxx-api\n";
+    return -1;
+  }
+
+  int32_t device_index = Pa_GetDefaultInputDevice();
+
+  const char *pDeviceIndex = std::getenv("SHERPA_ONNX_MIC_DEVICE");
+  if (pDeviceIndex) {
+    fprintf(stderr, "Use specified device: %s\n", pDeviceIndex);
+    device_index = atoi(pDeviceIndex);
+  }
+
+  for (int32_t i = 0; i != num_devices; ++i) {
+    const PaDeviceInfo *info = Pa_GetDeviceInfo(i);
+    fprintf(stderr, " %s %d %s\n", (i == device_index) ? "*" : " ", i,
+            info->name);
+  }
+
+  PaStreamParameters param;
+  param.device = device_index;
+
+  fprintf(stderr, "Use device: %d\n", param.device);
+
+  const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device);
+  fprintf(stderr, " Name: %s\n", info->name);
+  fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels);
+
+  param.channelCount = 1;
+  param.sampleFormat = paFloat32;
+
+  param.suggestedLatency = info->defaultLowInputLatency;
+  param.hostApiSpecificStreamInfo = nullptr;
+
+  float mic_sample_rate = 16000;
+  const char *sample_rate_str = std::getenv("SHERPA_ONNX_MIC_SAMPLE_RATE");
+  if (sample_rate_str) {
+    mic_sample_rate = atof(sample_rate_str);
+    fprintf(stderr, "Use sample rate %f for mic\n", mic_sample_rate);
+  }
+
+  float sample_rate = 16000;
+
+  LinearResampler resampler;
+  if (mic_sample_rate != sample_rate) {
+    float min_freq = std::min(mic_sample_rate, sample_rate);
+    float lowpass_cutoff = 0.99 * 0.5 * min_freq;
+
+    int32_t lowpass_filter_width = 6;
+    resampler = LinearResampler::Create(mic_sample_rate, sample_rate,
+                                        lowpass_cutoff, lowpass_filter_width);
+  }
+
+  PaStream *stream;
+  PaError err =
+      Pa_OpenStream(&stream, &param, nullptr, /* &outputParameters, */
+                    mic_sample_rate,
+                    0,          // frames per buffer
+                    paClipOff,  // we won't output out of range samples
+                                // so don't bother clipping them
+                    RecordCallback,  // RecordCallback is run in a separate
+                                     // thread created by portaudio
+                    nullptr);
+  if (err != paNoError) {
+    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
+    exit(EXIT_FAILURE);
+  }
+
+  err = Pa_StartStream(stream);
+  if (err != paNoError) {
+    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
+    exit(EXIT_FAILURE);
+  }
+
+  int32_t window_size = 512;  // samples, please don't change
+
+  int32_t offset = 0;
+  std::vector<float> buffer;
+  bool speech_started = false;
+
+  auto started_time = std::chrono::steady_clock::now();
+
+  SherpaDisplay display;
+
+  std::cout << "Started! Please speak\n";
+
+  while (!stop) {
+    {
+      std::unique_lock<std::mutex> lock(mutex);
+      while (samples_queue.empty() && !stop) {
+        condition_variable.wait(lock);
+      }
+
+      const auto &s = samples_queue.front();
+      if (!resampler.Get()) {
+        buffer.insert(buffer.end(), s.begin(), s.end());
+      } else {
+        auto resampled = resampler.Resample(s.data(), s.size(), false);
+        buffer.insert(buffer.end(), resampled.begin(), resampled.end());
+      }
+
+      samples_queue.pop();
+    }
+
+    for (; offset + window_size < buffer.size(); offset += window_size) {
+      vad.AcceptWaveform(buffer.data() + offset, window_size);
+      if (!speech_started && vad.IsDetected()) {
+        speech_started = true;
+        started_time = std::chrono::steady_clock::now();
+      }
+    }
+    if (!speech_started) {
+      if (buffer.size() > 10 * window_size) {
+        offset -= buffer.size() - 10 * window_size;
+        buffer = {buffer.end() - 10 * window_size, buffer.end()};
+      }
+    }
+
+    auto current_time = std::chrono::steady_clock::now();
+    const float elapsed_seconds =
+        std::chrono::duration_cast<std::chrono::milliseconds>(current_time -
+                                                              started_time)
+            .count() /
+        1000.;
+
+    if (speech_started && elapsed_seconds > 0.2) {
+      OfflineStream stream = recognizer.CreateStream();
+      stream.AcceptWaveform(sample_rate, buffer.data(), buffer.size());
+
+      recognizer.Decode(&stream);
+
+      OfflineRecognizerResult result = recognizer.GetResult(&stream);
+      display.UpdateText(result.text);
+      display.Display();
+
+      started_time = std::chrono::steady_clock::now();
+    }
+
+    while (!vad.IsEmpty()) {
+      auto segment = vad.Front();
+
+      vad.Pop();
+
+      OfflineStream stream = recognizer.CreateStream();
+      stream.AcceptWaveform(sample_rate, segment.samples.data(),
+                            segment.samples.size());
+
+      recognizer.Decode(&stream);
+
+      OfflineRecognizerResult result = recognizer.GetResult(&stream);
+
+      display.UpdateText(result.text);
+      display.FinalizeCurrentSentence();
+      display.Display();
+
+      buffer.clear();
+      offset = 0;
+      speech_started = false;
+    }
+  }
+
+  err = Pa_CloseStream(stream);
+  if (err != paNoError) {
+    fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err));
+    exit(EXIT_FAILURE);
+  }
+
+  return 0;
+}
diff --git a/cxx-api-examples/sense-voice-with-hr-cxx-api.cc b/cxx-api-examples/sense-voice-with-hr-cxx-api.cc
index 2de8425a..008a9ef6 100644
--- a/cxx-api-examples/sense-voice-with-hr-cxx-api.cc
+++ b/cxx-api-examples/sense-voice-with-hr-cxx-api.cc
@@ -47,8 +47,8 @@ int32_t main() {
   config.model_config.num_threads = 1;
   std::cout << "Loading model\n";
-  OfflineRecognizer recongizer = OfflineRecognizer::Create(config);
-  if (!recongizer.Get()) {
+  OfflineRecognizer recognizer = OfflineRecognizer::Create(config);
+  if (!recognizer.Get()) {
     std::cerr << "Please check your config\n";
     return -1;
   }
@@ -65,13 +65,13 @@ int32_t main() {
   std::cout << "Start recognition\n";
   const auto begin = std::chrono::steady_clock::now();
-  OfflineStream stream = recongizer.CreateStream();
+  OfflineStream stream = recognizer.CreateStream();
   stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
                         wave.samples.size());
-  recongizer.Decode(&stream);
+  recognizer.Decode(&stream);
-  OfflineRecognizerResult result = recongizer.GetResult(&stream);
+  OfflineRecognizerResult result = recognizer.GetResult(&stream);
   const auto end = std::chrono::steady_clock::now();
   const float elapsed_seconds =
diff --git a/cxx-api-examples/sherpa-display.h b/cxx-api-examples/sherpa-display.h
new file mode 100644
index 00000000..b0fc4605
--- /dev/null
+++ b/cxx-api-examples/sherpa-display.h
@@ -0,0 +1,66 @@
+#pragma once
+
+#include <stdio.h>
+
+#include <cstdlib>
+#include <ctime>
+#include <iomanip>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+namespace sherpa_onnx::cxx {
+
+class SherpaDisplay {
+ public:
+  void UpdateText(const std::string &text) { current_text_ = text; }
+
+  void FinalizeCurrentSentence() {
+    if (!current_text_.empty() && current_text_[0] != ' ') {
+      sentences_.push_back({GetCurrentDateTime(), std::move(current_text_)});
+    }
+  }
+
+  void Display() const {
+    if (!sentences_.empty() || !current_text_.empty()) {
+      ClearScreen();
+    }
+
+    printf("=== Speech Recognition with Next-gen Kaldi ===\n");
+    printf("------------------------------\n");
+    if (!sentences_.empty()) {
+      int32_t i = 1;
+      for (const auto &p : sentences_) {
+        printf("[%s] %d. %s\n", p.first.c_str(), i, p.second.c_str());
+        i += 1;
+      }
+
+      printf("------------------------------\n");
+    }
+
+    if (!current_text_.empty()) {
+      printf("Recognizing: %s\n", current_text_.c_str());
+    }
+  }
+
+ private:
+  static void ClearScreen() {
+#ifdef _MSC_VER
+    system("cls");
+#else
+    system("clear");
+#endif
+  }
+
+  static std::string GetCurrentDateTime() {
+    std::ostringstream os;
+    auto t = std::time(nullptr);
+    auto tm = std::localtime(&t);
+    os << std::put_time(tm, "%Y-%m-%d %H:%M:%S");
+    return os.str();
+  }
+
+ private:
+  std::vector<std::pair<std::string, std::string>> sentences_;
+  std::string current_text_;
+};
+
+}  // namespace sherpa_onnx::cxx
diff --git a/cxx-api-examples/streaming-zipformer-cxx-api.cc b/cxx-api-examples/streaming-zipformer-cxx-api.cc
index ac4abc47..f4c6226c 100644
--- a/cxx-api-examples/streaming-zipformer-cxx-api.cc
+++ b/cxx-api-examples/streaming-zipformer-cxx-api.cc
@@ -44,8 +44,8 @@ int32_t main() {
   config.model_config.num_threads = 1;
   std::cout << "Loading model\n";
-  OnlineRecognizer recongizer = OnlineRecognizer::Create(config);
-  if (!recongizer.Get()) {
+  OnlineRecognizer recognizer = OnlineRecognizer::Create(config);
+  if (!recognizer.Get()) {
     std::cerr << "Please check your config\n";
     return -1;
   }
@@ -63,16 +63,16 @@ int32_t main() {
   std::cout << "Start recognition\n";
   const auto begin = std::chrono::steady_clock::now();
-  OnlineStream stream = recongizer.CreateStream();
+  OnlineStream stream = recognizer.CreateStream();
   stream.AcceptWaveform(wave.sample_rate, wave.samples.data(),
                         wave.samples.size());
   stream.InputFinished();
-  while (recongizer.IsReady(&stream)) {
-    recongizer.Decode(&stream);
+  while (recognizer.IsReady(&stream)) {
+    recognizer.Decode(&stream);
   }
-  OnlineRecognizerResult result = recongizer.GetResult(&stream);
+  OnlineRecognizerResult result = recognizer.GetResult(&stream);
   const auto end = std::chrono::steady_clock::now();
   const float elapsed_seconds =
diff --git a/cxx-api-examples/streaming-zipformer-rtf-cxx-api.cc b/cxx-api-examples/streaming-zipformer-rtf-cxx-api.cc
index 2e74d30b..b9ff1aaf 100644
--- a/cxx-api-examples/streaming-zipformer-rtf-cxx-api.cc
+++ b/cxx-api-examples/streaming-zipformer-rtf-cxx-api.cc
@@ -73,8 +73,8 @@ int32_t main(int argc, char *argv[]) {
   config.model_config.provider = use_gpu ? 
"cuda" : "cpu"; std::cout << "Loading model\n"; - OnlineRecognizer recongizer = OnlineRecognizer::Create(config); - if (!recongizer.Get()) { + OnlineRecognizer recognizer = OnlineRecognizer::Create(config); + if (!recognizer.Get()) { std::cerr << "Please check your config\n"; return -1; } @@ -95,16 +95,16 @@ int32_t main(int argc, char *argv[]) { for (int32_t i = 0; i < num_runs; ++i) { const auto begin = std::chrono::steady_clock::now(); - OnlineStream stream = recongizer.CreateStream(); + OnlineStream stream = recognizer.CreateStream(); stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), wave.samples.size()); stream.InputFinished(); - while (recongizer.IsReady(&stream)) { - recongizer.Decode(&stream); + while (recognizer.IsReady(&stream)) { + recognizer.Decode(&stream); } - result = recongizer.GetResult(&stream); + result = recognizer.GetResult(&stream); auto end = std::chrono::steady_clock::now(); float elapsed_seconds = diff --git a/cxx-api-examples/streaming-zipformer-with-hr-cxx-api.cc b/cxx-api-examples/streaming-zipformer-with-hr-cxx-api.cc index e9e6ed4e..ad97a43d 100644 --- a/cxx-api-examples/streaming-zipformer-with-hr-cxx-api.cc +++ b/cxx-api-examples/streaming-zipformer-with-hr-cxx-api.cc @@ -59,8 +59,8 @@ int32_t main() { config.hr.rule_fsts = "./replace.fst"; std::cout << "Loading model\n"; - OnlineRecognizer recongizer = OnlineRecognizer::Create(config); - if (!recongizer.Get()) { + OnlineRecognizer recognizer = OnlineRecognizer::Create(config); + if (!recognizer.Get()) { std::cerr << "Please check your config\n"; return -1; } @@ -76,16 +76,16 @@ int32_t main() { std::cout << "Start recognition\n"; const auto begin = std::chrono::steady_clock::now(); - OnlineStream stream = recongizer.CreateStream(); + OnlineStream stream = recognizer.CreateStream(); stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), wave.samples.size()); stream.InputFinished(); - while (recongizer.IsReady(&stream)) { - recongizer.Decode(&stream); + while (recognizer.IsReady(&stream)) { + recognizer.Decode(&stream); } - OnlineRecognizerResult result = recongizer.GetResult(&stream); + OnlineRecognizerResult result = recognizer.GetResult(&stream); const auto end = std::chrono::steady_clock::now(); const float elapsed_seconds = diff --git a/cxx-api-examples/whisper-cxx-api.cc b/cxx-api-examples/whisper-cxx-api.cc index 348d115b..a805e788 100644 --- a/cxx-api-examples/whisper-cxx-api.cc +++ b/cxx-api-examples/whisper-cxx-api.cc @@ -32,8 +32,8 @@ int32_t main() { config.model_config.num_threads = 1; std::cout << "Loading model\n"; - OfflineRecognizer recongizer = OfflineRecognizer::Create(config); - if (!recongizer.Get()) { + OfflineRecognizer recognizer = OfflineRecognizer::Create(config); + if (!recognizer.Get()) { std::cerr << "Please check your config\n"; return -1; } @@ -49,13 +49,13 @@ int32_t main() { std::cout << "Start recognition\n"; const auto begin = std::chrono::steady_clock::now(); - OfflineStream stream = recongizer.CreateStream(); + OfflineStream stream = recognizer.CreateStream(); stream.AcceptWaveform(wave.sample_rate, wave.samples.data(), wave.samples.size()); - recongizer.Decode(&stream); + recognizer.Decode(&stream); - OfflineRecognizerResult result = recongizer.GetResult(&stream); + OfflineRecognizerResult result = recognizer.GetResult(&stream); const auto end = std::chrono::steady_clock::now(); const float elapsed_seconds = diff --git a/python-api-examples/simulate-streaming-sense-voice-microphone.py b/python-api-examples/simulate-streaming-sense-voice-microphone.py 
index d9a41ec4..a0294ab4 100755
--- a/python-api-examples/simulate-streaming-sense-voice-microphone.py
+++ b/python-api-examples/simulate-streaming-sense-voice-microphone.py
@@ -74,7 +74,7 @@ def get_args():
     parser.add_argument(
         "--num-threads",
         type=int,
-        default=1,
+        default=2,
         help="Number of threads for neural network computation",
     )
@@ -164,7 +164,13 @@ def main():
    config = sherpa_onnx.VadModelConfig()
    config.silero_vad.model = args.silero_vad_model
-   config.silero_vad.min_silence_duration = 0.25
+   config.silero_vad.threshold = 0.5
+   config.silero_vad.min_silence_duration = 0.1  # seconds
+   config.silero_vad.min_speech_duration = 0.25  # seconds
+   # If the current segment is larger than this value, then it increases
+   # the threshold to 0.9 internally. After detecting this segment,
+   # it resets the threshold to its original value.
+   config.silero_vad.max_speech_duration = 8  # seconds
    config.sample_rate = sample_rate
    window_size = config.silero_vad.window_size
@@ -184,20 +190,22 @@ def main():
    started = False
    started_time = None
+   offset = 0
    while not killed:
        samples = samples_queue.get()  # a blocking read
        buffer = np.concatenate([buffer, samples])
-       offset = 0
-       while offset + window_size < samples.shape[0]:
-           vad.accept_waveform(samples[offset : offset + window_size])
+       while offset + window_size < len(buffer):
+           vad.accept_waveform(buffer[offset : offset + window_size])
            if not started and vad.is_speech_detected():
                started = True
                started_time = time.time()
            offset += window_size
        if not started:
-           buffer = buffer[-10 * window_size :]
+           if len(buffer) > 10 * window_size:
+               offset -= len(buffer) - 10 * window_size
+               buffer = buffer[-10 * window_size :]
        if started and time.time() - started_time > 0.2:
            stream = recognizer.create_stream()
@@ -223,6 +231,7 @@ def main():
                display.update_text(text)
            buffer = []
+           offset = 0
            started = False
            started_time = None
diff --git a/sherpa-onnx/c-api/cxx-api.cc b/sherpa-onnx/c-api/cxx-api.cc
index dea70579..9e84f3b4 100644
--- a/sherpa-onnx/c-api/cxx-api.cc
+++ b/sherpa-onnx/c-api/cxx-api.cc
@@ -678,4 +678,42 @@ void VoiceActivityDetector::Flush() const {
   SherpaOnnxVoiceActivityDetectorFlush(p_);
 }
+LinearResampler LinearResampler::Create(int32_t samp_rate_in_hz,
+                                        int32_t samp_rate_out_hz,
+                                        float filter_cutoff_hz,
+                                        int32_t num_zeros) {
+  auto p = SherpaOnnxCreateLinearResampler(samp_rate_in_hz, samp_rate_out_hz,
+                                           filter_cutoff_hz, num_zeros);
+  return LinearResampler(p);
+}
+
+LinearResampler::LinearResampler(const SherpaOnnxLinearResampler *p)
+    : MoveOnly<LinearResampler, SherpaOnnxLinearResampler>(p) {}
+
+void LinearResampler::Destroy(const SherpaOnnxLinearResampler *p) const {
+  SherpaOnnxDestroyLinearResampler(p);
+}
+
+void LinearResampler::Reset() const { SherpaOnnxLinearResamplerReset(p_); }
+
+std::vector<float> LinearResampler::Resample(const float *input,
+                                             int32_t input_dim,
+                                             bool flush) const {
+  auto out = SherpaOnnxLinearResamplerResample(p_, input, input_dim, flush);
+
+  std::vector<float> ans{out->samples, out->samples + out->n};
+
+  SherpaOnnxLinearResamplerResampleFree(out);
+
+  return ans;
+}
+
+int32_t LinearResampler::GetInputSamplingRate() const {
+  return SherpaOnnxLinearResamplerResampleGetInputSampleRate(p_);
+}
+
+int32_t LinearResampler::GetOutputSamplingRate() const {
+  return SherpaOnnxLinearResamplerResampleGetOutputSampleRate(p_);
+}
+
 }  // namespace sherpa_onnx::cxx
diff --git a/sherpa-onnx/c-api/cxx-api.h b/sherpa-onnx/c-api/cxx-api.h
index a8fd6552..28ea4ee2 100644
--- a/sherpa-onnx/c-api/cxx-api.h
+++ b/sherpa-onnx/c-api/cxx-api.h
@@ -111,6 +111,7 @@ SHERPA_ONNX_API bool WriteWave(const std::string &filename, const Wave &wave);
 template <typename Derived, typename T>
 class SHERPA_ONNX_API MoveOnly {
  public:
+  MoveOnly() = default;
   explicit MoveOnly(const T *p) : p_(p) {}
   ~MoveOnly() { Destroy(); }
@@ -591,6 +592,28 @@ class SHERPA_ONNX_API VoiceActivityDetector
   explicit VoiceActivityDetector(const SherpaOnnxVoiceActivityDetector *p);
 };
+class SHERPA_ONNX_API LinearResampler
+    : public MoveOnly<LinearResampler, SherpaOnnxLinearResampler> {
+ public:
+  LinearResampler() = default;
+  static LinearResampler Create(int32_t samp_rate_in_hz,
+                                int32_t samp_rate_out_hz,
+                                float filter_cutoff_hz, int32_t num_zeros);
+
+  void Destroy(const SherpaOnnxLinearResampler *p) const;
+
+  void Reset() const;
+
+  std::vector<float> Resample(const float *input, int32_t input_dim,
+                              bool flush) const;
+
+  int32_t GetInputSamplingRate() const;
+  int32_t GetOutputSamplingRate() const;
+
+ private:
+  explicit LinearResampler(const SherpaOnnxLinearResampler *p);
+};
+
 }  // namespace sherpa_onnx::cxx
 #endif  // SHERPA_ONNX_C_API_CXX_API_H_
diff --git a/sherpa-onnx/csrc/homophone-replacer.cc b/sherpa-onnx/csrc/homophone-replacer.cc
index 69696a94..ca36b783 100644
--- a/sherpa-onnx/csrc/homophone-replacer.cc
+++ b/sherpa-onnx/csrc/homophone-replacer.cc
@@ -166,20 +166,32 @@ class HomophoneReplacer::Impl {
     }
     // convert words to pronunciations
-    std::vector<std::string> pronunciations;
+    std::vector<std::string> current_words;
+    std::vector<std::string> current_pronunciations;
     for (const auto &w : words) {
+      if (w.size() < 3 ||
+          reinterpret_cast<const uint8_t *>(w.data())[0] < 128) {
+        if (!current_words.empty()) {
+          ans += ApplyImpl(current_words, current_pronunciations);
+          current_words.clear();
+          current_pronunciations.clear();
+        }
+        ans += w;
+        continue;
+      }
+
       auto p = ConvertWordToPronunciation(w);
       if (config_.debug) {
         SHERPA_ONNX_LOGE("%s %s", w.c_str(), p.c_str());
       }
-      pronunciations.push_back(std::move(p));
+
+      current_words.push_back(w);
+      current_pronunciations.push_back(std::move(p));
     }
-    for (const auto &r : replacer_list_) {
-      ans = r->Normalize(words, pronunciations);
-      // TODO(fangjun): We support only 1 rule fst at present.
-      break;
+    if (!current_words.empty()) {
+      ans += ApplyImpl(current_words, current_pronunciations);
     }
     if (config_.debug) {
@@ -190,6 +202,16 @@ class HomophoneReplacer::Impl {
   }
  private:
+  std::string ApplyImpl(const std::vector<std::string> &words,
+                        const std::vector<std::string> &pronunciations) const {
+    std::string ans;
+    for (const auto &r : replacer_list_) {
+      ans = r->Normalize(words, pronunciations);
+      // TODO(fangjun): We support only 1 rule fst at present.
+      break;
+    }
+    return ans;
+  }
   std::string ConvertWordToPronunciation(const std::string &word) const {
     if (word2pron_.count(word)) {
       return word2pron_.at(word);
     }
@@ -239,6 +261,9 @@ class HomophoneReplacer::Impl {
     }
     while (iss >> p) {
+      if (p.back() > '4') {
+        p.push_back('1');
+      }
       pron.append(std::move(p));
     }
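Below is a minimal standalone sketch of the LinearResampler C++ API that this patch adds to cxx-api.h / cxx-api.cc. It is not part of the patch: the file name, the 48 kHz input rate, and the synthetic sine input are illustrative assumptions; the Create/Resample/Get calls and the cutoff and filter-width choices mirror what the microphone example above does.

// linear-resampler-sketch.cc (hypothetical file, for illustration only)
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

#include "sherpa-onnx/c-api/cxx-api.h"

int32_t main() {
  using namespace sherpa_onnx::cxx;  // NOLINT

  int32_t in_rate = 48000;   // assumed microphone sample rate
  int32_t out_rate = 16000;  // rate expected by the recognizer

  // Same parameter choices as the microphone example above: a low-pass
  // cutoff just below the Nyquist frequency of the lower rate, 6 filter zeros.
  float lowpass_cutoff = 0.99f * 0.5f * std::min(in_rate, out_rate);
  int32_t lowpass_filter_width = 6;

  LinearResampler resampler = LinearResampler::Create(
      in_rate, out_rate, lowpass_cutoff, lowpass_filter_width);
  if (!resampler.Get()) {
    fprintf(stderr, "Failed to create the resampler\n");
    return -1;
  }

  // 100 ms of a 440 Hz sine wave at the input rate.
  std::vector<float> samples(in_rate / 10);
  for (int32_t i = 0; i != static_cast<int32_t>(samples.size()); ++i) {
    samples[i] = 0.5f * std::sin(2.0f * 3.1415926f * 440.0f * i / in_rate);
  }

  // flush = true because this is the last (and only) chunk.
  std::vector<float> out = resampler.Resample(
      samples.data(), static_cast<int32_t>(samples.size()), true);

  fprintf(stderr, "in: %d samples @ %d Hz, out: %d samples @ %d Hz\n",
          static_cast<int32_t>(samples.size()),
          resampler.GetInputSamplingRate(),
          static_cast<int32_t>(out.size()),
          resampler.GetOutputSamplingRate());

  return 0;
}

Such a sketch would be compiled and linked against sherpa-onnx-cxx-api the same way the other cxx-api examples in CMakeLists.txt are built.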