diff --git a/CMakeLists.txt b/CMakeLists.txt index 2067f132..90b50bb7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.7.14") +set(SHERPA_ONNX_VERSION "1.7.15") # Disable warning about # diff --git a/cmake/cmake_extension.py b/cmake/cmake_extension.py index e4d45bb3..4ff1f8c4 100644 --- a/cmake/cmake_extension.py +++ b/cmake/cmake_extension.py @@ -136,6 +136,7 @@ class BuildExtension(build_ext): binaries += ["sherpa-onnx-online-websocket-server"] binaries += ["sherpa-onnx-offline-websocket-server"] binaries += ["sherpa-onnx-online-websocket-client"] + binaries += ["sherpa-onnx-vad-microphone"] if is_windows(): binaries += ["kaldi-native-fbank-core.dll"] diff --git a/python-api-examples/README.md b/python-api-examples/README.md new file mode 100644 index 00000000..a1e54a36 --- /dev/null +++ b/python-api-examples/README.md @@ -0,0 +1,9 @@ +# File description + +- [./http_server.py](./http_server.py) It defines which files to server. + Files are saved in [./web](./web). +- [non_streaming_server.py](./non_streaming_server.py) WebSocket server for + non-streaming models. +- [vad-remove-non-speech-segments.py](./vad-remove-non-speech-segments.py) It uses + [silero-vad](https://github.com/snakers4/silero-vad) to remove non-speech + segments and concatenate all speech segments into a single one. diff --git a/python-api-examples/vad-remove-non-speech-segments.py b/python-api-examples/vad-remove-non-speech-segments.py new file mode 100755 index 00000000..e55d88b0 --- /dev/null +++ b/python-api-examples/vad-remove-non-speech-segments.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 + +""" +This file shows how to remove non-speech segments +and merge all speech segments into a large segment +and save it to a file. + +Usage + +python3 ./vad-remove-non-speech-segments.py \ + --silero-vad-model silero_vad.onnx + +Please visit +https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx +to download silero_vad.onnx + +For instance, + +wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx +""" + +import argparse +import sys +import time +from pathlib import Path + +import numpy as np +import sherpa_onnx +import soundfile as sf + +try: + import sounddevice as sd +except ImportError: + print("Please install sounddevice first. You can use") + print() + print(" pip install sounddevice") + print() + print("to install it") + sys.exit(-1) + + +def assert_file_exists(filename: str): + assert Path(filename).is_file(), ( + f"{filename} does not exist!\n" + "Please refer to " + "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" + ) + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--silero-vad-model", + type=str, + required=True, + help="Path to silero_vad.onnx", + ) + + return parser.parse_args() + + +def main(): + devices = sd.query_devices() + if len(devices) == 0: + print("No microphone devices found") + sys.exit(0) + + print(devices) + default_input_device_idx = sd.default.device[0] + print(f'Use default device: {devices[default_input_device_idx]["name"]}') + + args = get_args() + assert_file_exists(args.silero_vad_model) + + sample_rate = 16000 + samples_per_read = int(0.1 * sample_rate) # 0.1 second = 100 ms + + config = sherpa_onnx.VadModelConfig() + config.silero_vad.model = args.silero_vad_model + config.sample_rate = sample_rate + + window_size = config.silero_vad.window_size + + buffer = [] + vad = sherpa_onnx.VoiceActivityDetector(config, buffer_size_in_seconds=30) + + all_samples = [] + + print("Started! Please speak") + + try: + with sd.InputStream(channels=1, dtype="float32", samplerate=sample_rate) as s: + while True: + samples, _ = s.read(samples_per_read) # a blocking read + samples = samples.reshape(-1) + buffer = np.concatenate([buffer, samples]) + + all_samples = np.concatenate([all_samples, samples]) + + while len(buffer) > window_size: + vad.accept_waveform(buffer[:window_size]) + buffer = buffer[window_size:] + except KeyboardInterrupt: + print("\nCaught Ctrl + C. Saving & Exiting") + + speech_samples = [] + while not vad.empty(): + speech_samples.extend(vad.front.samples) + vad.pop() + + speech_samples = np.array(speech_samples, dtype=np.float32) + + filename_for_speech = time.strftime("%Y%m%d-%H%M%S-speech.wav") + sf.write(filename_for_speech, speech_samples, samplerate=sample_rate) + + filename_for_all = time.strftime("%Y%m%d-%H%M%S-all.wav") + sf.write(filename_for_all, all_samples, samplerate=sample_rate) + + print(f"Saved to {filename_for_speech} and {filename_for_all}") + + +if __name__ == "__main__": + main() diff --git a/setup.py b/setup.py index 97bf3d86..d6084a19 100644 --- a/setup.py +++ b/setup.py @@ -56,6 +56,7 @@ def get_binaries_to_install(): binaries += ["sherpa-onnx-online-websocket-server"] binaries += ["sherpa-onnx-offline-websocket-server"] binaries += ["sherpa-onnx-online-websocket-client"] + binaries += ["sherpa-onnx-vad-microphone"] if is_windows(): binaries += ["kaldi-native-fbank-core.dll"] binaries += ["sherpa-onnx-c-api.dll"] @@ -95,8 +96,8 @@ setuptools.setup( "Topic :: Scientific/Engineering :: Artificial Intelligence", ], entry_points={ - 'console_scripts': [ - 'sherpa-onnx-cli=sherpa_onnx.cli:cli', + "console_scripts": [ + "sherpa-onnx-cli=sherpa_onnx.cli:cli", ], }, license="Apache licensed, as found in the LICENSE file", diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 8753af62..14136f86 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -13,6 +13,7 @@ endif() set(sources base64-decode.cc cat.cc + circular-buffer.cc context-graph.cc endpoint.cc features.cc @@ -66,6 +67,8 @@ set(sources provider.cc resample.cc session.cc + silero-vad-model-config.cc + silero-vad-model.cc slice.cc stack.cc symbol-table.cc @@ -73,6 +76,9 @@ set(sources transpose.cc unbind.cc utils.cc + vad-model-config.cc + vad-model.cc + voice-activity-detector.cc wave-reader.cc ) @@ -172,32 +178,42 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO) microphone.cc ) + add_executable(sherpa-onnx-vad-microphone + sherpa-onnx-vad-microphone.cc + microphone.cc + ) + if(BUILD_SHARED_LIBS) set(PA_LIB portaudio) else() set(PA_LIB portaudio_static) endif() - target_link_libraries(sherpa-onnx-microphone ${PA_LIB} sherpa-onnx-core) - target_link_libraries(sherpa-onnx-microphone-offline ${PA_LIB} sherpa-onnx-core) + set(exes + sherpa-onnx-microphone + sherpa-onnx-microphone-offline + sherpa-onnx-vad-microphone + ) + foreach(exe IN LISTS exes) + target_link_libraries(${exe} ${PA_LIB} sherpa-onnx-core) + endforeach() if(NOT WIN32) - target_link_libraries(sherpa-onnx-microphone "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") - target_link_libraries(sherpa-onnx-microphone "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib") - - target_link_libraries(sherpa-onnx-microphone-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") - target_link_libraries(sherpa-onnx-microphone-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib") + foreach(exe IN LISTS exes) + target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib") + target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../../../sherpa_onnx/lib") + endforeach() if(SHERPA_ONNX_ENABLE_PYTHON) - target_link_libraries(sherpa-onnx-microphone "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") - target_link_libraries(sherpa-onnx-microphone-offline "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") + + foreach(exe IN LISTS exes) + target_link_libraries(${exe} "-Wl,-rpath,${SHERPA_ONNX_RPATH_ORIGIN}/../lib/python${PYTHON_VERSION}/site-packages/sherpa_onnx/lib") + endforeach() endif() endif() install( - TARGETS - sherpa-onnx-microphone - sherpa-onnx-microphone-offline + TARGETS ${exes} DESTINATION bin ) @@ -269,6 +285,7 @@ endif() if(SHERPA_ONNX_ENABLE_TESTS) set(sherpa_onnx_test_srcs cat-test.cc + circular-buffer-test.cc context-graph-test.cc packed-sequence-test.cc pad-sequence-test.cc diff --git a/sherpa-onnx/csrc/README.md b/sherpa-onnx/csrc/README.md new file mode 100644 index 00000000..f073bb06 --- /dev/null +++ b/sherpa-onnx/csrc/README.md @@ -0,0 +1,29 @@ +# File descriptions + +- [./sherpa-onnx-alsa.cc](./sherpa-onnx-alsa.cc) For Linux only, especially for + embedded Linux, e.g., Raspberry Pi; it uses a streaming model for real-time + speech recognition with a microphone. + +- [./sherpa-onnx-microphone.cc](./sherpa-onnx-microphone.cc) + For Linux/Windows/macOS; it uses a streaming model for real-time speech + recognition with a microphone. + +- [./sherpa-onnx-microphone-offline.cc](./sherpa-onnx-microphone-offline.cc) + For Linux/Windows/macOS; it uses a non-streaming model for speech + recognition with a microphone. + +- [./sherpa-onnx.cc](./sherpa-onnx.cc) + It uses a streaming model to decode wave files + +- [./sherpa-onnx-offline.cc](./sherpa-onnx-offline.cc) + It uses a non-streaming model to decode wave files + +- [./online-websocket-server.cc](./online-websocket-server.cc) + WebSocket server for streaming models. + +- [./offline-websocket-server.cc](./offline-websocket-server.cc) + WebSocket server for non-streaming models. + +- [./sherpa-onnx-vad-microphone.cc](./sherpa-onnx-vad-microphone.cc) + Use silero VAD to detect speeches with a microphone. + diff --git a/sherpa-onnx/csrc/circular-buffer-test.cc b/sherpa-onnx/csrc/circular-buffer-test.cc new file mode 100644 index 00000000..66f9e99a --- /dev/null +++ b/sherpa-onnx/csrc/circular-buffer-test.cc @@ -0,0 +1,150 @@ +// sherpa-onnx/csrc/circular-buffer-test.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/circular-buffer.h" + +#include + +#include "gtest/gtest.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +TEST(CircularBuffer, Push) { + CircularBuffer buffer(10); + EXPECT_EQ(buffer.Size(), 0); + EXPECT_EQ(buffer.Head(), 0); + EXPECT_EQ(buffer.Tail(), 0); + + std::vector a = {0, 1, 2, 3, 4, 5}; + buffer.Push(a.data(), a.size()); + + EXPECT_EQ(buffer.Size(), 6); + EXPECT_EQ(buffer.Head(), 0); + EXPECT_EQ(buffer.Tail(), 6); + + auto c = buffer.Get(0, a.size()); + EXPECT_EQ(a.size(), c.size()); + for (int32_t i = 0; i != a.size(); ++i) { + EXPECT_EQ(a[i], c[i]); + } + + std::vector d = {-6, -7, -8, -9}; + buffer.Push(d.data(), d.size()); + + c = buffer.Get(a.size(), d.size()); + EXPECT_EQ(d.size(), c.size()); + for (int32_t i = 0; i != d.size(); ++i) { + EXPECT_EQ(d[i], c[i]); + } +} + +TEST(CircularBuffer, PushAndPop) { + CircularBuffer buffer(5); + std::vector a = {0, 1, 2, 3}; + buffer.Push(a.data(), a.size()); + + EXPECT_EQ(buffer.Size(), 4); + EXPECT_EQ(buffer.Head(), 0); + EXPECT_EQ(buffer.Tail(), 4); + + buffer.Pop(2); + + EXPECT_EQ(buffer.Size(), 2); + EXPECT_EQ(buffer.Head(), 2); + EXPECT_EQ(buffer.Tail(), 4); + + auto c = buffer.Get(2, 2); + EXPECT_EQ(c.size(), 2); + EXPECT_EQ(c[0], 2); + EXPECT_EQ(c[1], 3); + + a = {10, 20, 30}; + buffer.Push(a.data(), a.size()); + EXPECT_EQ(buffer.Size(), 5); + EXPECT_EQ(buffer.Head(), 2); + EXPECT_EQ(buffer.Tail(), 7); + + c = buffer.Get(2, 5); + EXPECT_EQ(c.size(), 5); + EXPECT_EQ(c[0], 2); + EXPECT_EQ(c[1], 3); + EXPECT_EQ(c[2], 10); + EXPECT_EQ(c[3], 20); + EXPECT_EQ(c[4], 30); + + c = buffer.Get(3, 4); + EXPECT_EQ(c.size(), 4); + EXPECT_EQ(c[0], 3); + EXPECT_EQ(c[1], 10); + EXPECT_EQ(c[2], 20); + EXPECT_EQ(c[3], 30); + + c = buffer.Get(4, 3); + EXPECT_EQ(c.size(), 3); + EXPECT_EQ(c[0], 10); + EXPECT_EQ(c[1], 20); + EXPECT_EQ(c[2], 30); + + buffer.Pop(4); + EXPECT_EQ(buffer.Size(), 1); + EXPECT_EQ(buffer.Head(), 6); + EXPECT_EQ(buffer.Tail(), 7); + + c = buffer.Get(6, 1); + EXPECT_EQ(c.size(), 1); + EXPECT_EQ(c[0], 30); + + a = {100, 200, 300, 400}; + buffer.Push(a.data(), a.size()); + EXPECT_EQ(buffer.Size(), 5); + + EXPECT_EQ(buffer.Size(), 5); + EXPECT_EQ(buffer.Head(), 6); + EXPECT_EQ(buffer.Tail(), 11); + + c = buffer.Get(6, 5); + EXPECT_EQ(c.size(), 5); + EXPECT_EQ(c[0], 30); + EXPECT_EQ(c[1], 100); + EXPECT_EQ(c[2], 200); + EXPECT_EQ(c[3], 300); + EXPECT_EQ(c[4], 400); + + buffer.Pop(3); + EXPECT_EQ(buffer.Size(), 2); + EXPECT_EQ(buffer.Head(), 9); + EXPECT_EQ(buffer.Tail(), 11); + + c = buffer.Get(10, 1); + EXPECT_EQ(c.size(), 1); + EXPECT_EQ(c[0], 400); + + a = {1000, 2000, 3000}; + buffer.Push(a.data(), a.size()); + + EXPECT_EQ(buffer.Size(), 5); + EXPECT_EQ(buffer.Head(), 9); + EXPECT_EQ(buffer.Tail(), 14); + + buffer.Pop(1); + + EXPECT_EQ(buffer.Size(), 4); + EXPECT_EQ(buffer.Head(), 10); + EXPECT_EQ(buffer.Tail(), 14); + + a = {4000}; + + buffer.Push(a.data(), a.size()); + EXPECT_EQ(buffer.Size(), 5); + EXPECT_EQ(buffer.Head(), 10); + EXPECT_EQ(buffer.Tail(), 15); + + c = buffer.Get(13, 2); + EXPECT_EQ(c.size(), 2); + EXPECT_EQ(c[0], 3000); + EXPECT_EQ(c[1], 4000); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/circular-buffer.cc b/sherpa-onnx/csrc/circular-buffer.cc new file mode 100644 index 00000000..b9702ace --- /dev/null +++ b/sherpa-onnx/csrc/circular-buffer.cc @@ -0,0 +1,96 @@ +// sherpa-onnx/csrc/circular-buffer.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/circular-buffer.h" + +#include + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +CircularBuffer::CircularBuffer(int32_t capacity) { + if (capacity <= 0) { + SHERPA_ONNX_LOGE("Please specify a positive capacity. Given: %d\n", + capacity); + exit(-1); + } + buffer_.resize(capacity); +} + +void CircularBuffer::Push(const float *p, int32_t n) { + int32_t capacity = buffer_.size(); + int32_t size = Size(); + if (n + size > capacity) { + SHERPA_ONNX_LOGE("Overflow! n: %d, size: %d, n+size: %d, capacity: %d", n, + size, n + size, capacity); + exit(-1); + } + + int32_t start = tail_ % capacity; + + tail_ += n; + + if (start + n < capacity) { + std::copy(p, p + n, buffer_.begin() + start); + return; + } + + int32_t part1_size = capacity - start; + + std::copy(p, p + part1_size, buffer_.begin() + start); + + std::copy(p + part1_size, p + n, buffer_.begin()); +} + +std::vector CircularBuffer::Get(int32_t start_index, int32_t n) const { + if (start_index < head_ || start_index >= tail_) { + SHERPA_ONNX_LOGE("Invalid start_index: %d. head_: %d, tail_: %d", + start_index, head_, tail_); + return {}; + } + + int32_t size = Size(); + if (n < 0 || n > size) { + SHERPA_ONNX_LOGE("Invalid n: %d. size: %d", n, size); + return {}; + } + + int32_t capacity = buffer_.size(); + + if (start_index - head_ + n > size) { + SHERPA_ONNX_LOGE("Invalid start_index: %d and n: %d. head_: %d, size: %d", + start_index, n, head_, size); + return {}; + } + + int32_t start = start_index % capacity; + + if (start + n < capacity) { + return {buffer_.begin() + start, buffer_.begin() + start + n}; + } + + std::vector ans(n); + + std::copy(buffer_.begin() + start, buffer_.end(), ans.begin()); + + int32_t part1_size = capacity - start; + int32_t part2_size = n - part1_size; + std::copy(buffer_.begin(), buffer_.begin() + part2_size, + ans.begin() + part1_size); + + return ans; +} + +void CircularBuffer::Pop(int32_t n) { + int32_t size = Size(); + if (n < 0 || n > size) { + SHERPA_ONNX_LOGE("Invalid n: %d. size: %d", n, size); + return; + } + + head_ += n; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/circular-buffer.h b/sherpa-onnx/csrc/circular-buffer.h new file mode 100644 index 00000000..6b0419e3 --- /dev/null +++ b/sherpa-onnx/csrc/circular-buffer.h @@ -0,0 +1,59 @@ +// sherpa-onnx/csrc/circular-buffer.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_CIRCULAR_BUFFER_H_ +#define SHERPA_ONNX_CSRC_CIRCULAR_BUFFER_H_ + +#include +#include + +namespace sherpa_onnx { + +class CircularBuffer { + public: + // Capacity of this buffer. Should be large enough. + // If it is full, we just print a message and exit the program. + explicit CircularBuffer(int32_t capacity); + + // Push an array + // + // @param p Pointer to the start address of the array + // @param n Number of elements in the array + // + // Note: If n + Size() > capacity, we print an error message and exit. + void Push(const float *p, int32_t n); + + // @param start_index Should in the range [head_, tail_) + // @param n Number of elements to get + // @return Return a vector of size n containing the requested elements + std::vector Get(int32_t start_index, int32_t n) const; + + // Remove n elements from the buffer + // + // @param n Should be in the range [0, size_] + void Pop(int32_t n); + + // Number of elements in the buffer. + int32_t Size() const { return tail_ - head_; } + + // Current position of the head + int32_t Head() const { return head_; } + + // Current position of the tail + int32_t Tail() const { return tail_; } + + void Reset() { + head_ = 0; + tail_ = 0; + } + + private: + std::vector buffer_; + + int32_t head_ = 0; // linear index; always increasing; never wraps around + int32_t tail_ = 0; // linear index, always increasing; never wraps around. +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_CIRCULAR_BUFFER_H_ diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 80c9471d..fe747740 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -76,4 +76,8 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) { return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider); } +Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) { + return GetSessionOptionsImpl(config.num_threads, config.provider); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 42c93e0b..f0f25b23 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -10,6 +10,7 @@ #include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" +#include "sherpa-onnx/csrc/vad-model-config.h" namespace sherpa_onnx { @@ -20,6 +21,8 @@ Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config); Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config); Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config); + +Ort::SessionOptions GetSessionOptions(const VadModelConfig &config); } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_SESSION_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc b/sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc new file mode 100644 index 00000000..1953645e --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc @@ -0,0 +1,159 @@ +// sherpa-onnx/csrc/sherpa-onnx-vad-microphone.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include +#include +#include + +#include +#include // NOLINT + +#include "portaudio.h" // NOLINT +#include "sherpa-onnx/csrc/circular-buffer.h" +#include "sherpa-onnx/csrc/microphone.h" +#include "sherpa-onnx/csrc/voice-activity-detector.h" + +bool stop = false; +std::mutex mutex; +sherpa_onnx::CircularBuffer buffer(16000 * 60); + +static int32_t RecordCallback(const void *input_buffer, + void * /*output_buffer*/, + unsigned long frames_per_buffer, // NOLINT + const PaStreamCallbackTimeInfo * /*time_info*/, + PaStreamCallbackFlags /*status_flags*/, + void *user_data) { + std::lock_guard lock(mutex); + buffer.Push(reinterpret_cast(input_buffer), frames_per_buffer); + + return stop ? paComplete : paContinue; +} + +static void Handler(int32_t sig) { + stop = true; + fprintf(stderr, "\nCaught Ctrl + C. Exiting...\n"); +} + +int32_t main(int32_t argc, char *argv[]) { + signal(SIGINT, Handler); + + const char *kUsageMessage = R"usage( +This program shows how to use VAD in sherpa-onnx. + + ./bin/sherpa-onnx-vad-microphone \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --provider=cpu \ + --num-threads=1 + +Please download silero_vad.onnx from +https://github.com/snakers4/silero-vad/blob/master/files/silero_vad.onnx + +For instance, use +wget https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx +)usage"; + + sherpa_onnx::ParseOptions po(kUsageMessage); + sherpa_onnx::VadModelConfig config; + + config.Register(&po); + po.Read(argc, argv); + if (po.NumArgs() != 0) { + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + sherpa_onnx::Microphone mic; + + PaDeviceIndex num_devices = Pa_GetDeviceCount(); + fprintf(stderr, "Num devices: %d\n", num_devices); + + PaStreamParameters param; + + param.device = Pa_GetDefaultInputDevice(); + if (param.device == paNoDevice) { + fprintf(stderr, "No default input device found\n"); + exit(EXIT_FAILURE); + } + fprintf(stderr, "Use default device: %d\n", param.device); + + const PaDeviceInfo *info = Pa_GetDeviceInfo(param.device); + fprintf(stderr, " Name: %s\n", info->name); + fprintf(stderr, " Max input channels: %d\n", info->maxInputChannels); + + param.channelCount = 1; + param.sampleFormat = paFloat32; + + param.suggestedLatency = info->defaultLowInputLatency; + param.hostApiSpecificStreamInfo = nullptr; + float sample_rate = 16000; + + PaStream *stream; + PaError err = + Pa_OpenStream(&stream, ¶m, nullptr, /* &outputParameters, */ + sample_rate, + 0, // frames per buffer + paClipOff, // we won't output out of range samples + // so don't bother clipping them + RecordCallback, &config.silero_vad.window_size); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + err = Pa_StartStream(stream); + + auto vad = std::make_unique(config); + + fprintf(stderr, "Started\n"); + + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + int32_t window_size = config.silero_vad.window_size; + bool printed = false; + + while (!stop) { + { + std::lock_guard lock(mutex); + + while (buffer.Size() >= window_size) { + std::vector samples = buffer.Get(buffer.Head(), window_size); + buffer.Pop(window_size); + vad->AcceptWaveform(samples.data(), samples.size()); + + if (vad->IsSpeechDetected() && !printed) { + printed = true; + fprintf(stderr, "\nDetected speech!\n"); + } + if (!vad->IsSpeechDetected()) { + printed = false; + } + + while (!vad->Empty()) { + float duration = vad->Front().samples.size() / sample_rate; + vad->Pop(); + fprintf(stderr, "Duration: %.3f seconds\n", duration); + } + } + } + Pa_Sleep(100); // sleep for 100ms + } + + err = Pa_CloseStream(stream); + if (err != paNoError) { + fprintf(stderr, "portaudio error: %s\n", Pa_GetErrorText(err)); + exit(EXIT_FAILURE); + } + + return 0; +} diff --git a/sherpa-onnx/csrc/silero-vad-model-config.cc b/sherpa-onnx/csrc/silero-vad-model-config.cc new file mode 100644 index 00000000..8419265f --- /dev/null +++ b/sherpa-onnx/csrc/silero-vad-model-config.cc @@ -0,0 +1,81 @@ +// sherpa-onnx/csrc/silero-vad-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/silero-vad-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void SileroVadModelConfig::Register(ParseOptions *po) { + po->Register("silero-vad-model", &model, "Path to silero VAD ONNX model."); + + po->Register("silero-vad-threshold", &threshold, + "Speech threshold. Silero VAD outputs speech probabilities for " + "each audio chunk, probabilities ABOVE this value are " + "considered as SPEECH. It is better to tune this parameter for " + "each dataset separately, but lazy " + "0.5 is pretty good for most datasets."); + + po->Register( + "silero-vad-min-silence-duration", &min_silence_duration, + "In seconds. In the end of each speech chunk wait for " + "--silero-vad-min-silence-duration seconds before separating it"); + + po->Register("silero-vad-min-speech-duration", &min_speech_duration, + "In seconds. In the end of each silence chunk wait for " + "--silero-vad-min-speech-duration seconds before separating it"); + + po->Register( + "silero-vad-window-size", &window_size, + "In samples. Audio chunks of --silero-vad-window-size samples are fed " + "to the silero VAD model. WARNING! Silero VAD models were trained using " + "512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples " + "for 8000 sample rate. Values other than these may affect model " + "perfomance!"); +} + +bool SileroVadModelConfig::Validate() const { + if (model.empty()) { + SHERPA_ONNX_LOGE("Please provide --silero-vad-model"); + return false; + } + + if (!FileExists(model)) { + SHERPA_ONNX_LOGE("Silero vad model file %s does not exist", model.c_str()); + return false; + } + + if (threshold < 0.01) { + SHERPA_ONNX_LOGE( + "Please use a larger value for --silero-vad-threshold. Given: %f", + threshold); + return false; + } + + if (threshold >= 1) { + SHERPA_ONNX_LOGE( + "Please use a smaller value for --silero-vad-threshold. Given: %f", + threshold); + return false; + } + + return true; +} + +std::string SileroVadModelConfig::ToString() const { + std::ostringstream os; + + os << "SilerVadModelConfig("; + os << "model=\"" << model << "\", "; + os << "threshold=" << threshold << ", "; + os << "min_silence_duration=" << min_silence_duration << ", "; + os << "min_speech_duration=" << min_speech_duration << ", "; + os << "window_size=" << window_size << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/silero-vad-model-config.h b/sherpa-onnx/csrc/silero-vad-model-config.h new file mode 100644 index 00000000..fc930963 --- /dev/null +++ b/sherpa-onnx/csrc/silero-vad-model-config.h @@ -0,0 +1,41 @@ +// sherpa-onnx/csrc/silero-vad-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SILERO_VAD_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_SILERO_VAD_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct SileroVadModelConfig { + std::string model; + + // threshold to classify a segment as speech + // + // The predicted probability of a segment is larger than this + // value, then it is classified as speech. + float threshold = 0.5; + + float min_silence_duration = 0.5; // in seconds + + float min_speech_duration = 0.25; // in seconds + + // 512, 1024, 1536 samples for 16000 Hz + // 256, 512, 768 samples for 800 Hz + int window_size = 512; // in samples + + SileroVadModelConfig() = default; + + void Register(ParseOptions *po); + + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SILERO_VAD_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/silero-vad-model.cc b/sherpa-onnx/csrc/silero-vad-model.cc new file mode 100644 index 00000000..462c8b3a --- /dev/null +++ b/sherpa-onnx/csrc/silero-vad-model.cc @@ -0,0 +1,281 @@ +// sherpa-onnx/csrc/silero-vad-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/silero-vad-model.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" + +namespace sherpa_onnx { + +class SileroVadModel::Impl { + public: + explicit Impl(const VadModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config.silero_vad.model); + Init(buf.data(), buf.size()); + + sample_rate_ = config.sample_rate; + if (sample_rate_ != 16000) { + SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d", + config.sample_rate); + exit(-1); + } + + min_silence_samples_ = + sample_rate_ * config_.silero_vad.min_silence_duration; + + min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration; + } + + void Reset() { + // 2 - number of LSTM layer + // 1 - batch size + // 64 - hidden dim + std::array shape{2, 1, 64}; + + Ort::Value h = + Ort::Value::CreateTensor(allocator_, shape.data(), shape.size()); + + Ort::Value c = + Ort::Value::CreateTensor(allocator_, shape.data(), shape.size()); + + Fill(&h, 0); + Fill(&c, 0); + + states_.clear(); + + states_.reserve(2); + states_.push_back(std::move(h)); + states_.push_back(std::move(c)); + + triggered_ = false; + current_sample_ = 0; + temp_start_ = 0; + temp_end_ = 0; + } + + bool IsSpeech(const float *samples, int32_t n) { + if (n != config_.silero_vad.window_size) { + SHERPA_ONNX_LOGE("n: %d != window_size: %d", n, + config_.silero_vad.window_size); + exit(-1); + } + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + std::array x_shape = {1, n}; + + Ort::Value x = + Ort::Value::CreateTensor(memory_info, const_cast(samples), n, + x_shape.data(), x_shape.size()); + + int64_t sr_shape = 1; + Ort::Value sr = + Ort::Value::CreateTensor(memory_info, &sample_rate_, 1, &sr_shape, 1); + + std::array inputs = {std::move(x), std::move(sr), + std::move(states_[0]), + std::move(states_[1])}; + + auto out = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + + states_[0] = std::move(out[1]); + states_[1] = std::move(out[2]); + + float prob = out[0].GetTensorData()[0]; + + float threshold = config_.silero_vad.threshold; + + current_sample_ += config_.silero_vad.window_size; + + if (prob > threshold && temp_end_ != 0) { + temp_end_ = 0; + } + + if (prob > threshold && temp_start_ == 0) { + // start speaking, but we require that it must satisfy + // min_speech_duration + temp_start_ = current_sample_; + return false; + } + + if (prob > threshold && temp_start_ != 0 && !triggered_) { + if (current_sample_ - temp_start_ < min_speech_samples_) { + return false; + } + + triggered_ = true; + + return true; + } + + if ((prob < threshold) && !triggered_) { + // silence + temp_start_ = 0; + temp_end_ = 0; + return false; + } + + if ((prob > threshold - 0.15) && triggered_) { + // speaking + return true; + } + + if ((prob > threshold) && !triggered_) { + // start speaking + triggered_ = true; + + return true; + } + + if ((prob < threshold) && triggered_) { + // stop to speak + if (temp_end_ == 0) { + temp_end_ = current_sample_; + } + + if (current_sample_ - temp_end_ < min_silence_samples_) { + // continue speaking + return true; + } + // stopped speaking + temp_start_ = 0; + temp_end_ = 0; + triggered_ = false; + return false; + } + + return false; + } + + int32_t WindowSize() const { return config_.silero_vad.window_size; } + + int32_t MinSilenceDurationSamples() const { return min_silence_samples_; } + + int32_t MinSpeechDurationSamples() const { return min_speech_samples_; } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + Check(); + + Reset(); + } + + void Check() { + if (input_names_.size() != 4) { + SHERPA_ONNX_LOGE("Expect 4 inputs. Given: %d", + static_cast(input_names_.size())); + exit(-1); + } + + if (input_names_[0] != "input") { + SHERPA_ONNX_LOGE("Input[0]: %s. Expected: input", + input_names_[0].c_str()); + exit(-1); + } + + if (input_names_[1] != "sr") { + SHERPA_ONNX_LOGE("Input[1]: %s. Expected: sr", input_names_[1].c_str()); + exit(-1); + } + + if (input_names_[2] != "h") { + SHERPA_ONNX_LOGE("Input[2]: %s. Expected: h", input_names_[2].c_str()); + exit(-1); + } + + if (input_names_[3] != "c") { + SHERPA_ONNX_LOGE("Input[3]: %s. Expected: c", input_names_[3].c_str()); + exit(-1); + } + + // Now for outputs + if (output_names_.size() != 3) { + SHERPA_ONNX_LOGE("Expect 3 outputs. Given: %d", + static_cast(output_names_.size())); + exit(-1); + } + + if (output_names_[0] != "output") { + SHERPA_ONNX_LOGE("Output[0]: %s. Expected: output", + output_names_[0].c_str()); + exit(-1); + } + + if (output_names_[1] != "hn") { + SHERPA_ONNX_LOGE("Output[1]: %s. Expected: sr", output_names_[1].c_str()); + exit(-1); + } + + if (output_names_[2] != "cn") { + SHERPA_ONNX_LOGE("Output[2]: %s. Expected: sr", output_names_[2].c_str()); + exit(-1); + } + } + + private: + VadModelConfig config_; + + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + std::vector states_; + int64_t sample_rate_; + int32_t min_silence_samples_; + int32_t min_speech_samples_; + + bool triggered_ = false; + int32_t current_sample_ = 0; + int32_t temp_start_ = 0; + int32_t temp_end_ = 0; +}; + +SileroVadModel::SileroVadModel(const VadModelConfig &config) + : impl_(std::make_unique(config)) {} + +SileroVadModel::~SileroVadModel() = default; + +void SileroVadModel::Reset() { return impl_->Reset(); } + +bool SileroVadModel::IsSpeech(const float *samples, int32_t n) { + return impl_->IsSpeech(samples, n); +} + +int32_t SileroVadModel::WindowSize() const { return impl_->WindowSize(); } + +int32_t SileroVadModel::MinSilenceDurationSamples() const { + return impl_->MinSilenceDurationSamples(); +} + +int32_t SileroVadModel::MinSpeechDurationSamples() const { + return impl_->MinSpeechDurationSamples(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/silero-vad-model.h b/sherpa-onnx/csrc/silero-vad-model.h new file mode 100644 index 00000000..7dcf02fe --- /dev/null +++ b/sherpa-onnx/csrc/silero-vad-model.h @@ -0,0 +1,42 @@ +// sherpa-onnx/csrc/silero-vad-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SILERO_VAD_MODEL_H_ +#define SHERPA_ONNX_CSRC_SILERO_VAD_MODEL_H_ + +#include + +#include "sherpa-onnx/csrc/vad-model.h" + +namespace sherpa_onnx { + +class SileroVadModel : public VadModel { + public: + explicit SileroVadModel(const VadModelConfig &config); + ~SileroVadModel() override; + + // reset the internal model states + void Reset() override; + + /** + * @param samples Pointer to a 1-d array containing audio samples. + * Each sample should be normalized to the range [-1, 1]. + * @param n Number of samples. + * + * @return Return true if speech is detected. Return false otherwise. + */ + bool IsSpeech(const float *samples, int32_t n) override; + + int32_t WindowSize() const override; + + int32_t MinSilenceDurationSamples() const override; + int32_t MinSpeechDurationSamples() const override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SILERO_VAD_MODEL_H_ diff --git a/sherpa-onnx/csrc/vad-model-config.cc b/sherpa-onnx/csrc/vad-model-config.cc new file mode 100644 index 00000000..f02ad01c --- /dev/null +++ b/sherpa-onnx/csrc/vad-model-config.cc @@ -0,0 +1,44 @@ +// sherpa-onnx/csrc/vad-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/vad-model-config.h" + +#include +#include + +namespace sherpa_onnx { + +void VadModelConfig::Register(ParseOptions *po) { + silero_vad.Register(po); + + po->Register("vad-sample-rate", &sample_rate, + "Sample rate expected by the VAD model"); + + po->Register("vad-num-threads", &num_threads, + "Number of threads to run the VAD model"); + + po->Register("vad-provider", &provider, + "Specify a provider to run the VAD model. Supported values: " + "cpu, cuda, coreml"); + + po->Register("vad-debug", &debug, + "true to display debug information when loading vad models"); +} + +bool VadModelConfig::Validate() const { return silero_vad.Validate(); } + +std::string VadModelConfig::ToString() const { + std::ostringstream os; + + os << "VadModelConfig("; + os << "silero_vad=" << silero_vad.ToString() << ", "; + os << "sample_rate=" << sample_rate << ", "; + os << "num_threads=" << num_threads << ", "; + os << "provider=\"" << provider << "\", "; + os << "debug=" << (debug ? "True" : "False") << ")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/vad-model-config.h b/sherpa-onnx/csrc/vad-model-config.h new file mode 100644 index 00000000..725ab340 --- /dev/null +++ b/sherpa-onnx/csrc/vad-model-config.h @@ -0,0 +1,42 @@ +// sherpa-onnx/csrc/vad-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_VAD_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_VAD_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" +#include "sherpa-onnx/csrc/silero-vad-model-config.h" + +namespace sherpa_onnx { + +struct VadModelConfig { + SileroVadModelConfig silero_vad; + + int32_t sample_rate = 16000; + int32_t num_threads = 1; + std::string provider = "cpu"; + + // true to show debug information when loading models + bool debug = false; + + VadModelConfig() = default; + + VadModelConfig(const SileroVadModelConfig &silero_vad, int32_t sample_rate, + int32_t num_threads, const std::string &provider, bool debug) + : silero_vad(silero_vad), + sample_rate(sample_rate), + num_threads(num_threads), + provider(provider), + debug(debug) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_VAD_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/vad-model.cc b/sherpa-onnx/csrc/vad-model.cc new file mode 100644 index 00000000..47d3fc68 --- /dev/null +++ b/sherpa-onnx/csrc/vad-model.cc @@ -0,0 +1,16 @@ +// sherpa-onnx/csrc/vad-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/vad-model.h" + +#include "sherpa-onnx/csrc/silero-vad-model.h" + +namespace sherpa_onnx { + +std::unique_ptr VadModel::Create(const VadModelConfig &config) { + // TODO(fangjun): Support other VAD models. + return std::make_unique(config); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/vad-model.h b/sherpa-onnx/csrc/vad-model.h new file mode 100644 index 00000000..7227b1ff --- /dev/null +++ b/sherpa-onnx/csrc/vad-model.h @@ -0,0 +1,39 @@ +// sherpa-onnx/csrc/vad-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_VAD_MODEL_H_ +#define SHERPA_ONNX_CSRC_VAD_MODEL_H_ + +#include + +#include "sherpa-onnx/csrc/vad-model-config.h" + +namespace sherpa_onnx { + +class VadModel { + public: + virtual ~VadModel() = default; + + static std::unique_ptr Create(const VadModelConfig &config); + + // reset the internal model states + virtual void Reset() = 0; + + /** + * @param samples Pointer to a 1-d array containing audio samples. + * Each sample should be normalized to the range [-1, 1]. + * @param n Number of samples. Should be equal to WindowSize() + * + * @return Return true if speech is detected. Return false otherwise. + */ + virtual bool IsSpeech(const float *samples, int32_t n) = 0; + + virtual int32_t WindowSize() const = 0; + + virtual int32_t MinSilenceDurationSamples() const = 0; + virtual int32_t MinSpeechDurationSamples() const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_VAD_MODEL_H_ diff --git a/sherpa-onnx/csrc/voice-activity-detector.cc b/sherpa-onnx/csrc/voice-activity-detector.cc new file mode 100644 index 00000000..2109a0b9 --- /dev/null +++ b/sherpa-onnx/csrc/voice-activity-detector.cc @@ -0,0 +1,104 @@ +// sherpa-onnx/csrc/voice-activity-detector.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/voice-activity-detector.h" + +#include +#include + +#include "sherpa-onnx/csrc/circular-buffer.h" +#include "sherpa-onnx/csrc/vad-model.h" + +namespace sherpa_onnx { + +class VoiceActivityDetector::Impl { + public: + explicit Impl(const VadModelConfig &config, float buffer_size_in_seconds = 60) + : model_(VadModel::Create(config)), + config_(config), + buffer_(buffer_size_in_seconds * config.sample_rate) {} + + void AcceptWaveform(const float *samples, int32_t n) { + buffer_.Push(samples, n); + + bool is_speech = model_->IsSpeech(samples, n); + if (is_speech) { + if (start_ == -1) { + // beginning of speech + start_ = buffer_.Tail() - 2 * model_->WindowSize() - + model_->MinSpeechDurationSamples(); + } + } else { + // non-speech + if (start_ != -1) { + // end of speech, save the speech segment + int32_t end = buffer_.Tail() - model_->MinSilenceDurationSamples(); + + std::vector samples = buffer_.Get(start_, end - start_); + SpeechSegment segment; + + segment.start = start_; + segment.samples = std::move(samples); + + segments_.push(std::move(segment)); + + buffer_.Pop(end - buffer_.Head()); + } + + start_ = -1; + } + } + + bool Empty() const { return segments_.empty(); } + + void Pop() { segments_.pop(); } + + const SpeechSegment &Front() const { return segments_.front(); } + + void Reset() { + std::queue().swap(segments_); + + model_->Reset(); + buffer_.Reset(); + + start_ = -1; + } + + bool IsSpeechDetected() const { return start_ != -1; } + + private: + std::queue segments_; + + std::unique_ptr model_; + VadModelConfig config_; + CircularBuffer buffer_; + + int32_t start_ = -1; +}; + +VoiceActivityDetector::VoiceActivityDetector( + const VadModelConfig &config, float buffer_size_in_seconds /*= 60*/) + : impl_(std::make_unique(config, buffer_size_in_seconds)) {} + +VoiceActivityDetector::~VoiceActivityDetector() = default; + +void VoiceActivityDetector::AcceptWaveform(const float *samples, int32_t n) { + impl_->AcceptWaveform(samples, n); +} + +bool VoiceActivityDetector::Empty() const { return impl_->Empty(); } + +void VoiceActivityDetector::Pop() { impl_->Pop(); } + +const SpeechSegment &VoiceActivityDetector::Front() const { + return impl_->Front(); +} + +void VoiceActivityDetector::Reset() { impl_->Reset(); } + +bool VoiceActivityDetector::IsSpeechDetected() const { + return impl_->IsSpeechDetected(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/voice-activity-detector.h b/sherpa-onnx/csrc/voice-activity-detector.h new file mode 100644 index 00000000..59483823 --- /dev/null +++ b/sherpa-onnx/csrc/voice-activity-detector.h @@ -0,0 +1,41 @@ +// sherpa-onnx/csrc/voice-activity-detector.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_VOICE_ACTIVITY_DETECTOR_H_ +#define SHERPA_ONNX_CSRC_VOICE_ACTIVITY_DETECTOR_H_ + +#include +#include + +#include "sherpa-onnx/csrc/vad-model-config.h" + +namespace sherpa_onnx { + +struct SpeechSegment { + int32_t start; // in samples + std::vector samples; +}; + +class VoiceActivityDetector { + public: + explicit VoiceActivityDetector(const VadModelConfig &config, + float buffer_size_in_seconds = 60); + ~VoiceActivityDetector(); + + void AcceptWaveform(const float *samples, int32_t n); + bool Empty() const; + void Pop(); + const SpeechSegment &Front() const; + + bool IsSpeechDetected() const; + + void Reset(); + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_VOICE_ACTIVITY_DETECTOR_H_ diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt index d61e4303..1973a1c7 100644 --- a/sherpa-onnx/python/csrc/CMakeLists.txt +++ b/sherpa-onnx/python/csrc/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(${CMAKE_SOURCE_DIR}) pybind11_add_module(_sherpa_onnx + circular-buffer.cc display.cc endpoint.cc features.cc @@ -20,6 +21,10 @@ pybind11_add_module(_sherpa_onnx online-stream.cc online-transducer-model-config.cc sherpa-onnx.cc + silero-vad-model-config.cc + vad-model-config.cc + vad-model.cc + voice-activity-detector.cc ) if(APPLE) diff --git a/sherpa-onnx/python/csrc/circular-buffer.cc b/sherpa-onnx/python/csrc/circular-buffer.cc new file mode 100644 index 00000000..20ea4b51 --- /dev/null +++ b/sherpa-onnx/python/csrc/circular-buffer.cc @@ -0,0 +1,31 @@ +// sherpa-onnx/python/csrc/circular-buffer.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/circular-buffer.h" + +#include + +#include "sherpa-onnx/csrc/circular-buffer.h" + +namespace sherpa_onnx { + +void PybindCircularBuffer(py::module *m) { + using PyClass = CircularBuffer; + py::class_(*m, "CircularBuffer") + .def(py::init(), py::arg("capacity")) + .def( + "push", + [](PyClass &self, const std::vector &samples) { + self.Push(samples.data(), samples.size()); + }, + py::arg("samples")) + .def("get", &PyClass::Get, py::arg("start_index"), py::arg("n")) + .def("pop", &PyClass::Pop, py::arg("n")) + .def("reset", &PyClass::Reset) + .def_property_readonly("size", &PyClass::Size) + .def_property_readonly("head", &PyClass::Head) + .def_property_readonly("tail", &PyClass::Tail); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/circular-buffer.h b/sherpa-onnx/python/csrc/circular-buffer.h new file mode 100644 index 00000000..4c4383bd --- /dev/null +++ b/sherpa-onnx/python/csrc/circular-buffer.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/circular-buffer.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_CIRCULAR_BUFFER_H_ +#define SHERPA_ONNX_PYTHON_CSRC_CIRCULAR_BUFFER_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindCircularBuffer(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_CIRCULAR_BUFFER_H_ diff --git a/sherpa-onnx/python/csrc/sherpa-onnx.cc b/sherpa-onnx/python/csrc/sherpa-onnx.cc index 64f8aacf..98547df8 100644 --- a/sherpa-onnx/python/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/python/csrc/sherpa-onnx.cc @@ -4,6 +4,7 @@ #include "sherpa-onnx/python/csrc/sherpa-onnx.h" +#include "sherpa-onnx/python/csrc/circular-buffer.h" #include "sherpa-onnx/python/csrc/display.h" #include "sherpa-onnx/python/csrc/endpoint.h" #include "sherpa-onnx/python/csrc/features.h" @@ -15,6 +16,9 @@ #include "sherpa-onnx/python/csrc/online-model-config.h" #include "sherpa-onnx/python/csrc/online-recognizer.h" #include "sherpa-onnx/python/csrc/online-stream.h" +#include "sherpa-onnx/python/csrc/vad-model-config.h" +#include "sherpa-onnx/python/csrc/vad-model.h" +#include "sherpa-onnx/python/csrc/voice-activity-detector.h" namespace sherpa_onnx { @@ -34,6 +38,11 @@ PYBIND11_MODULE(_sherpa_onnx, m) { PybindOfflineLMConfig(&m); PybindOfflineModelConfig(&m); PybindOfflineRecognizer(&m); + + PybindVadModelConfig(&m); + PybindVadModel(&m); + PybindCircularBuffer(&m); + PybindVoiceActivityDetector(&m); } } // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/silero-vad-model-config.cc b/sherpa-onnx/python/csrc/silero-vad-model-config.cc new file mode 100644 index 00000000..3065a02d --- /dev/null +++ b/sherpa-onnx/python/csrc/silero-vad-model-config.cc @@ -0,0 +1,43 @@ +// sherpa-onnx/python/csrc/silero-vad-model-config.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/silero-vad-model-config.h" + +#include +#include + +#include "sherpa-onnx/csrc/silero-vad-model-config.h" + +namespace sherpa_onnx { + +void PybindSileroVadModelConfig(py::module *m) { + using PyClass = SileroVadModelConfig; + py::class_(*m, "SileroVadModelConfig") + .def(py::init<>()) + .def(py::init([](const std::string &model, float threshold, + float min_silence_duration, float min_speech_duration, + int32_t window_size) -> std::unique_ptr { + auto ans = std::make_unique(); + + ans->model = model; + ans->threshold = threshold; + ans->min_silence_duration = min_silence_duration; + ans->min_speech_duration = min_speech_duration; + ans->window_size = window_size; + + return ans; + }), + py::arg("model"), py::arg("threshold") = 0.5, + py::arg("min_silence_duration") = 0.5, + py::arg("min_speech_duration") = 0.25, py::arg("window_size") = 512) + .def_readwrite("model", &PyClass::model) + .def_readwrite("threshold", &PyClass::threshold) + .def_readwrite("min_silence_duration", &PyClass::min_silence_duration) + .def_readwrite("min_speech_duration", &PyClass::min_speech_duration) + .def_readwrite("window_size", &PyClass::window_size) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/silero-vad-model-config.h b/sherpa-onnx/python/csrc/silero-vad-model-config.h new file mode 100644 index 00000000..52997367 --- /dev/null +++ b/sherpa-onnx/python/csrc/silero-vad-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/silero-vad-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_SILERO_VAD_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_SILERO_VAD_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindSileroVadModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_SILERO_VAD_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/vad-model-config.cc b/sherpa-onnx/python/csrc/vad-model-config.cc new file mode 100644 index 00000000..9be0ba98 --- /dev/null +++ b/sherpa-onnx/python/csrc/vad-model-config.cc @@ -0,0 +1,34 @@ +// sherpa-onnx/python/csrc/vad-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/vad-model-config.h" + +#include + +#include "sherpa-onnx/csrc/vad-model-config.h" +#include "sherpa-onnx/python/csrc/silero-vad-model-config.h" + +namespace sherpa_onnx { + +void PybindVadModelConfig(py::module *m) { + PybindSileroVadModelConfig(m); + + using PyClass = VadModelConfig; + py::class_(*m, "VadModelConfig") + .def(py::init<>()) + .def(py::init(), + py::arg("silero_vad"), py::arg("sample_rate") = 16000, + py::arg("num_threads") = 1, py::arg("provider") = "cpu", + py::arg("debug") = false) + .def_readwrite("silero_vad", &PyClass::silero_vad) + .def_readwrite("sample_rate", &PyClass::sample_rate) + .def_readwrite("num_threads", &PyClass::num_threads) + .def_readwrite("provider", &PyClass::provider) + .def_readwrite("debug", &PyClass::debug) + .def("__str__", &PyClass::ToString) + .def("validate", &PyClass::Validate); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/vad-model-config.h b/sherpa-onnx/python/csrc/vad-model-config.h new file mode 100644 index 00000000..c19842c1 --- /dev/null +++ b/sherpa-onnx/python/csrc/vad-model-config.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/vad-model-config.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_VAD_MODEL_CONFIG_H_ +#define SHERPA_ONNX_PYTHON_CSRC_VAD_MODEL_CONFIG_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindVadModelConfig(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_VAD_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/python/csrc/vad-model.cc b/sherpa-onnx/python/csrc/vad-model.cc new file mode 100644 index 00000000..11c81cbb --- /dev/null +++ b/sherpa-onnx/python/csrc/vad-model.cc @@ -0,0 +1,29 @@ +// sherpa-onnx/python/csrc/vad-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/vad-model.h" + +#include + +#include "sherpa-onnx/csrc/vad-model.h" + +namespace sherpa_onnx { + +void PybindVadModel(py::module *m) { + using PyClass = VadModel; + py::class_(*m, "VadModel") + .def_static("create", &PyClass::Create, py::arg("config")) + .def("reset", &PyClass::Reset) + .def( + "is_speech", + [](PyClass &self, const std::vector &samples) -> bool { + return self.IsSpeech(samples.data(), samples.size()); + }, + py::arg("samples")) + .def("window_size", &PyClass::WindowSize) + .def("min_silence_duration_samples", &PyClass::MinSilenceDurationSamples) + .def("min_speech_duration_samples", &PyClass::MinSpeechDurationSamples); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/vad-model.h b/sherpa-onnx/python/csrc/vad-model.h new file mode 100644 index 00000000..79c8debd --- /dev/null +++ b/sherpa-onnx/python/csrc/vad-model.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/vad-model.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_VAD_MODEL_H_ +#define SHERPA_ONNX_PYTHON_CSRC_VAD_MODEL_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindVadModel(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_VAD_MODEL_H_ diff --git a/sherpa-onnx/python/csrc/voice-activity-detector.cc b/sherpa-onnx/python/csrc/voice-activity-detector.cc new file mode 100644 index 00000000..237e32ab --- /dev/null +++ b/sherpa-onnx/python/csrc/voice-activity-detector.cc @@ -0,0 +1,41 @@ +// sherpa-onnx/python/csrc/voice-activity-detector.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/python/csrc/voice-activity-detector.h" + +#include + +#include "sherpa-onnx/csrc/voice-activity-detector.h" + +namespace sherpa_onnx { + +void PybindSpeechSegment(py::module *m) { + using PyClass = SpeechSegment; + py::class_(*m, "SpeechSegment") + .def_property_readonly("start", + [](const PyClass &self) { return self.start; }) + .def_property_readonly("samples", + [](const PyClass &self) { return self.samples; }); +} + +void PybindVoiceActivityDetector(py::module *m) { + PybindSpeechSegment(m); + using PyClass = VoiceActivityDetector; + py::class_(*m, "VoiceActivityDetector") + .def(py::init(), py::arg("config"), + py::arg("buffer_size_in_seconds") = 60) + .def( + "accept_waveform", + [](PyClass &self, const std::vector &samples) { + self.AcceptWaveform(samples.data(), samples.size()); + }, + py::arg("samples")) + .def("empty", &PyClass::Empty) + .def("pop", &PyClass::Pop) + .def("is_speech_detected", &PyClass::IsSpeechDetected) + .def("reset", &PyClass::Reset) + .def_property_readonly("front", &PyClass::Front); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/python/csrc/voice-activity-detector.h b/sherpa-onnx/python/csrc/voice-activity-detector.h new file mode 100644 index 00000000..9e460b2f --- /dev/null +++ b/sherpa-onnx/python/csrc/voice-activity-detector.h @@ -0,0 +1,16 @@ +// sherpa-onnx/python/csrc/voice-activity-detector.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_PYTHON_CSRC_VOICE_ACTIVITY_DETECTOR_H_ +#define SHERPA_ONNX_PYTHON_CSRC_VOICE_ACTIVITY_DETECTOR_H_ + +#include "sherpa-onnx/python/csrc/sherpa-onnx.h" + +namespace sherpa_onnx { + +void PybindVoiceActivityDetector(py::module *m); + +} + +#endif // SHERPA_ONNX_PYTHON_CSRC_VOICE_ACTIVITY_DETECTOR_H_ diff --git a/sherpa-onnx/python/sherpa_onnx/__init__.py b/sherpa-onnx/python/sherpa_onnx/__init__.py index b21156c7..57a2302e 100644 --- a/sherpa-onnx/python/sherpa_onnx/__init__.py +++ b/sherpa-onnx/python/sherpa_onnx/__init__.py @@ -1,6 +1,16 @@ from typing import Dict, List, Optional -from _sherpa_onnx import Display, OfflineStream, OnlineStream +from _sherpa_onnx import ( + CircularBuffer, + Display, + OfflineStream, + OnlineStream, + SileroVadModelConfig, + SpeechSegment, + VadModel, + VadModelConfig, + VoiceActivityDetector, +) from .offline_recognizer import OfflineRecognizer from .online_recognizer import OnlineRecognizer