From 8e51a975508fd69d3eed53d5098862201889fafd Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 1 Apr 2025 15:56:56 +0800 Subject: [PATCH] Add C++ runtime for silero_vad with RKNN (#2078) --- c-api-examples/vad-whisper-c-api.c | 11 +- scripts/gtcrn/show.py | 75 +++ scripts/silero_vad/v4/export-onnx.py | 81 ++- scripts/silero_vad/v4/show.py | 2 +- scripts/silero_vad/v4/test-on-rk3588-board.py | 141 ++++++ scripts/silero_vad/v4/test-onnx.py | 3 + sherpa-onnx/csrc/CMakeLists.txt | 6 +- .../csrc/rknn/silero-vad-model-rknn.cc | 470 ++++++++++++++++++ sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h | 53 ++ .../csrc/sherpa-onnx-vad-alsa-offline-asr.cc | 4 +- sherpa-onnx/csrc/vad-model-config.cc | 25 +- sherpa-onnx/csrc/vad-model.cc | 12 +- 12 files changed, 867 insertions(+), 16 deletions(-) create mode 100755 scripts/gtcrn/show.py create mode 100755 scripts/silero_vad/v4/test-on-rk3588-board.py create mode 100644 sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc create mode 100644 sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h diff --git a/c-api-examples/vad-whisper-c-api.c b/c-api-examples/vad-whisper-c-api.c index c8860013..48e6b784 100644 --- a/c-api-examples/vad-whisper-c-api.c +++ b/c-api-examples/vad-whisper-c-api.c @@ -100,12 +100,11 @@ int32_t main() { while (!is_eof) { if (i + window_size < wave->num_samples) { - SherpaOnnxVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i, - window_size); - } - else { - SherpaOnnxVoiceActivityDetectorFlush(vad); - is_eof = 1; + SherpaOnnxVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i, + window_size); + } else { + SherpaOnnxVoiceActivityDetectorFlush(vad); + is_eof = 1; } while (!SherpaOnnxVoiceActivityDetectorEmpty(vad)) { const SherpaOnnxSpeechSegment *segment = diff --git a/scripts/gtcrn/show.py b/scripts/gtcrn/show.py new file mode 100755 index 00000000..0cbd87fc --- /dev/null +++ b/scripts/gtcrn/show.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) + +import onnxruntime +import onnx + +""" +[key: "model_type" +value: "gtcrn" +, key: "comment" +value: "gtcrn_simple" +, key: "version" +value: "1" +, key: "sample_rate" +value: "16000" +, key: "model_url" +value: "https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx" +, key: "maintainer" +value: "k2-fsa" +, key: "comment2" +value: "Please see also https://github.com/Xiaobin-Rong/gtcrn" +, key: "conv_cache_shape" +value: "2,1,16,16,33" +, key: "tra_cache_shape" +value: "2,3,1,1,16" +, key: "inter_cache_shape" +value: "2,1,33,16" +, key: "n_fft" +value: "512" +, key: "hop_length" +value: "256" +, key: "window_length" +value: "512" +, key: "window_type" +value: "hann_sqrt" +] +""" + +""" +NodeArg(name='mix', type='tensor(float)', shape=[1, 257, 1, 2]) +NodeArg(name='conv_cache', type='tensor(float)', shape=[2, 1, 16, 16, 33]) +NodeArg(name='tra_cache', type='tensor(float)', shape=[2, 3, 1, 1, 16]) +NodeArg(name='inter_cache', type='tensor(float)', shape=[2, 1, 33, 16]) +----- +NodeArg(name='enh', type='tensor(float)', shape=[1, 257, 1, 2]) +NodeArg(name='conv_cache_out', type='tensor(float)', shape=[2, 1, 16, 16, 33]) +NodeArg(name='tra_cache_out', type='tensor(float)', shape=[2, 3, 1, 1, 16]) +NodeArg(name='inter_cache_out', type='tensor(float)', shape=[2, 1, 33, 16]) +""" + + +def show(filename): + model = onnx.load(filename) + print(model.metadata_props) + + session_opts = onnxruntime.SessionOptions() + session_opts.log_severity_level = 3 + sess = onnxruntime.InferenceSession( + filename, session_opts, providers=["CPUExecutionProvider"] + ) + for i in sess.get_inputs(): + print(i) + + print("-----") + + for i in sess.get_outputs(): + print(i) + + +def main(): + show("./gtcrn_simple.onnx") + + +if __name__ == "__main__": + main() diff --git a/scripts/silero_vad/v4/export-onnx.py b/scripts/silero_vad/v4/export-onnx.py index 075d711d..f6b8c04d 100755 --- a/scripts/silero_vad/v4/export-onnx.py +++ b/scripts/silero_vad/v4/export-onnx.py @@ -5,15 +5,94 @@ import onnx import torch from onnxsim import simplify +import torch +from torch import Tensor + + +def simple_pad(x: Tensor, pad: int) -> Tensor: + # _0 = torch.slice(torch.slice(torch.slice(x), 1), 2, 1, torch.add(1, pad)) + _0 = x[:, :, 1 : 1 + pad] + + left_pad = torch.flip(_0, [-1]) + # _1 = torch.slice(torch.slice(torch.slice(x), 1), 2, torch.sub(-1, pad), -1) + + _1 = x[:, :, (-1 - pad) : -1] + + right_pad = torch.flip(_1, [-1]) + _2 = torch.cat([left_pad, x, right_pad], 2) + return _2 + + +class MyModule(torch.nn.Module): + def __init__(self, m): + super().__init__() + self.m = m + + def adaptive_normalization_forward(self, spect): + m = self.m._model.adaptive_normalization + _0 = simple_pad + + # Note(fangjun): rknn uses fp16 by default, whose max value is 65504 + # so we need to re-write the computation for spect0 + # spect0 = torch.log1p(torch.mul(spect, 1048576)) + spect0 = torch.log1p(spect) + 13.86294 + + _1 = torch.eq(len(spect0.shape), 2) + if _1: + _2 = torch.unsqueeze(spect0, 0) + spect1 = _2 + else: + spect1 = spect0 + mean = torch.mean(spect1, [1], True) + to_pad = m.to_pad + mean0 = _0( + mean, + to_pad, + ) + filter_ = m.filter_ + mean1 = torch.conv1d(mean0, filter_) + mean_mean = torch.mean(mean1, [-1], True) + spect2 = torch.add(spect1, torch.neg(mean_mean)) + return spect2 + + def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor): + m = self.m._model + + feature_extractor = m.feature_extractor + x0 = (feature_extractor).forward( + x, + ) + norm = self.adaptive_normalization_forward(x0) + x1 = torch.cat([x0, norm], 1) + first_layer = m.first_layer + x2 = (first_layer).forward( + x1, + ) + encoder = m.encoder + x3 = (encoder).forward( + x2, + ) + decoder = m.decoder + x4, h0, c0, = (decoder).forward( + x3, + h, + c, + ) + _0 = torch.mean(torch.squeeze(x4, 1), [1]) + out = torch.unsqueeze(_0, 1) + return (out, h0, c0) + @torch.no_grad() def main(): m = torch.jit.load("./silero_vad.jit") + m = MyModule(m) x = torch.rand((1, 512), dtype=torch.float32) h = torch.rand((2, 1, 64), dtype=torch.float32) c = torch.rand((2, 1, 64), dtype=torch.float32) + m = torch.jit.script(m) torch.onnx.export( - m._model, + m, (x, h, c), "m.onnx", input_names=["x", "h", "c"], diff --git a/scripts/silero_vad/v4/show.py b/scripts/silero_vad/v4/show.py index 6a76b98e..80fbc46d 100755 --- a/scripts/silero_vad/v4/show.py +++ b/scripts/silero_vad/v4/show.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) import onnxruntime import onnx diff --git a/scripts/silero_vad/v4/test-on-rk3588-board.py b/scripts/silero_vad/v4/test-on-rk3588-board.py new file mode 100755 index 00000000..545569eb --- /dev/null +++ b/scripts/silero_vad/v4/test-on-rk3588-board.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang) + +# Please run this file on your rk3588 board + +try: + from rknnlite.api import RKNNLite +except: + print("Please run this file on your board (linux + aarch64 + npu)") + print("You need to install rknn_toolkit_lite2") + print( + " from https://github.com/airockchip/rknn-toolkit2/tree/master/rknn-toolkit-lite2/packages" + ) + print( + "https://github.com/airockchip/rknn-toolkit2/blob/v2.1.0/rknn-toolkit-lite2/packages/rknn_toolkit_lite2-2.1.0-cp310-cp310-linux_aarch64.whl" + ) + print("is known to work") + raise + +import time +from pathlib import Path +from typing import Tuple + +import numpy as np +import soundfile as sf + + +def load_audio(filename: str) -> Tuple[np.ndarray, int]: + data, sample_rate = sf.read( + filename, + always_2d=True, + dtype="float32", + ) + data = data[:, 0] # use only the first channel + + samples = np.ascontiguousarray(data) + return samples, sample_rate + + +def init_model(filename, target_platform="rk3588"): + if not Path(filename).is_file(): + exit(f"{filename} does not exist") + + rknn_lite = RKNNLite(verbose=False) + ret = rknn_lite.load_rknn(path=filename) + if ret != 0: + exit(f"Load model {filename} failed!") + + ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0) + if ret != 0: + exit(f"Failed to init rknn runtime for {filename}") + return rknn_lite + + +class RKNNModel: + def __init__(self, model: str, target_platform="rk3588"): + self.model = init_model(model) + + def release(self): + self.model.release() + + def __call__(self, x: np.ndarray, h: np.ndarray, c: np.ndarray): + """ + Args: + x: (1, 512), np.float32 + h: (2, 1, 64), np.float32 + c: (2, 1, 64), np.float32 + Returns: + prob: + next_h: + next_c + """ + out, next_h, next_c = self.model.inference(inputs=[x, h, c]) + return out.item(), next_h, next_c + + +def main(): + model = RKNNModel(model="./m.rknn") + for i in range(1): + test(model) + + +def test(model): + print("started") + start = time.time() + samples, sample_rate = load_audio("./lei-jun-test.wav") + assert sample_rate == 16000, sample_rate + + window_size = 512 + + h = np.zeros((2, 1, 64), dtype=np.float32) + c = np.zeros((2, 1, 64), dtype=np.float32) + + threshold = 0.5 + num_windows = samples.shape[0] // window_size + out = [] + for i in range(num_windows): + print(i, num_windows) + this_samples = samples[i * window_size : (i + 1) * window_size] + prob, h, c = model(this_samples[None], h, c) + out.append(prob > threshold) + + min_speech_duration = 0.25 * sample_rate / window_size + min_silence_duration = 0.25 * sample_rate / window_size + + result = [] + last = -1 + for k, f in enumerate(out): + if f >= threshold: + if last == -1: + last = k + elif last != -1: + if k - last > min_speech_duration: + result.append((last, k)) + last = -1 + + if last != -1 and k - last > min_speech_duration: + result.append((last, k)) + + if not result: + print("Empty for ./lei-jun-test.wav") + return + + print(result) + + final = [result[0]] + for r in result[1:]: + f = final[-1] + if r[0] - f[1] < min_silence_duration: + final[-1] = (f[0], r[1]) + else: + final.append(r) + + for f in final: + start = f[0] * window_size / sample_rate + end = f[1] * window_size / sample_rate + print("{:.3f} -- {:.3f}".format(start, end)) + + +if __name__ == "__main__": + main() diff --git a/scripts/silero_vad/v4/test-onnx.py b/scripts/silero_vad/v4/test-onnx.py index 4df09301..8644cf80 100755 --- a/scripts/silero_vad/v4/test-onnx.py +++ b/scripts/silero_vad/v4/test-onnx.py @@ -97,10 +97,13 @@ def main(): h, c = model.get_init_states() window_size = 512 num_windows = samples.shape[0] // window_size + for i in range(num_windows): start = i * window_size end = start + window_size + p, h, c = model(samples[start:end], h, c) + probs.append(p[0].item()) threshold = 0.5 diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 7de2d324..9aa192c0 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -159,6 +159,7 @@ if(SHERPA_ONNX_ENABLE_RKNN) ./rknn/online-transducer-modified-beam-search-decoder-rknn.cc ./rknn/online-zipformer-ctc-model-rknn.cc ./rknn/online-zipformer-transducer-model-rknn.cc + ./rknn/silero-vad-model-rknn.cc ./rknn/utils.cc ) @@ -468,6 +469,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY) microphone.cc ) + add_executable(sherpa-onnx-microphone-offline sherpa-onnx-microphone-offline.cc microphone.cc @@ -498,11 +500,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY) ) set(exes - sherpa-onnx-microphone sherpa-onnx-keyword-spotter-microphone + sherpa-onnx-microphone sherpa-onnx-microphone-offline - sherpa-onnx-microphone-offline-speaker-identification sherpa-onnx-microphone-offline-audio-tagging + sherpa-onnx-microphone-offline-speaker-identification sherpa-onnx-vad-microphone sherpa-onnx-vad-microphone-offline-asr sherpa-onnx-vad-with-offline-asr diff --git a/sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc b/sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc new file mode 100644 index 00000000..eb1963e9 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc @@ -0,0 +1,470 @@ +// sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h" + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/rknn/macros.h" +#include "sherpa-onnx/csrc/rknn/utils.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class SileroVadModelRknn::Impl { + public: + ~Impl() { + auto ret = rknn_destroy(ctx_); + if (ret != RKNN_SUCC) { + SHERPA_ONNX_LOGE("Failed to destroy the context"); + } + } + + explicit Impl(const VadModelConfig &config) + : config_(config), sample_rate_(config.sample_rate) { + auto buf = ReadFile(config.silero_vad.model); + Init(buf.data(), buf.size()); + + if (sample_rate_ != 16000) { + SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d", + config.sample_rate); + SHERPA_ONNX_EXIT(-1); + } + + min_silence_samples_ = + sample_rate_ * config_.silero_vad.min_silence_duration; + + min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration; + } + + template + Impl(Manager *mgr, const VadModelConfig &config) + : config_(config), sample_rate_(config.sample_rate) { + auto buf = ReadFile(mgr, config.silero_vad.model); + Init(buf.data(), buf.size()); + + if (sample_rate_ != 16000) { + SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d", + config.sample_rate); + exit(-1); + } + + min_silence_samples_ = + sample_rate_ * config_.silero_vad.min_silence_duration; + + min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration; + } + + void Reset() { + for (auto &s : states_) { + std::fill(s.begin(), s.end(), 0); + } + + triggered_ = false; + current_sample_ = 0; + temp_start_ = 0; + temp_end_ = 0; + } + + bool IsSpeech(const float *samples, int32_t n) { + if (n != WindowSize()) { + SHERPA_ONNX_LOGE("n: %d != window_size: %d", n, WindowSize()); + SHERPA_ONNX_EXIT(-1); + } + + float prob = Run(samples, n); + + float threshold = config_.silero_vad.threshold; + + current_sample_ += config_.silero_vad.window_size; + + if (prob > threshold && temp_end_ != 0) { + temp_end_ = 0; + } + + if (prob > threshold && temp_start_ == 0) { + // start speaking, but we require that it must satisfy + // min_speech_duration + temp_start_ = current_sample_; + return false; + } + + if (prob > threshold && temp_start_ != 0 && !triggered_) { + if (current_sample_ - temp_start_ < min_speech_samples_) { + return false; + } + + triggered_ = true; + + return true; + } + + if ((prob < threshold) && !triggered_) { + // silence + temp_start_ = 0; + temp_end_ = 0; + return false; + } + + if ((prob > threshold - 0.15) && triggered_) { + // speaking + return true; + } + + if ((prob > threshold) && !triggered_) { + // start speaking + triggered_ = true; + + return true; + } + + if ((prob < threshold) && triggered_) { + // stop to speak + if (temp_end_ == 0) { + temp_end_ = current_sample_; + } + + if (current_sample_ - temp_end_ < min_silence_samples_) { + // continue speaking + return true; + } + // stopped speaking + temp_start_ = 0; + temp_end_ = 0; + triggered_ = false; + return false; + } + + return false; + } + + int32_t WindowShift() const { return config_.silero_vad.window_size; } + + int32_t WindowSize() const { + return config_.silero_vad.window_size + window_overlap_; + } + + int32_t MinSilenceDurationSamples() const { return min_silence_samples_; } + + int32_t MinSpeechDurationSamples() const { return min_speech_samples_; } + + void SetMinSilenceDuration(float s) { + min_silence_samples_ = sample_rate_ * s; + } + + void SetThreshold(float threshold) { + config_.silero_vad.threshold = threshold; + } + + private: + void Init(void *model_data, size_t model_data_length) { + auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init silero vad model '%s'", + config_.silero_vad.model.c_str()); + + if (config_.debug) { + rknn_sdk_version v; + ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version"); + + SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version, + v.drv_version); + } + + rknn_input_output_num io_num; + ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model"); + + if (config_.debug) { + SHERPA_ONNX_LOGE("model: %d inputs, %d outputs", + static_cast(io_num.n_input), + static_cast(io_num.n_output)); + } + + input_attrs_.resize(io_num.n_input); + output_attrs_.resize(io_num.n_output); + + int32_t i = 0; + for (auto &attr : input_attrs_) { + memset(&attr, 0, sizeof(attr)); + attr.index = i; + ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i); + i += 1; + } + + if (config_.debug) { + std::ostringstream os; + std::string sep; + for (auto &attr : input_attrs_) { + os << sep << ToString(attr); + sep = "\n"; + } + SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s", + os.str().c_str()); + } + + i = 0; + for (auto &attr : output_attrs_) { + memset(&attr, 0, sizeof(attr)); + attr.index = i; + ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i); + i += 1; + } + + if (config_.debug) { + std::ostringstream os; + std::string sep; + for (auto &attr : output_attrs_) { + os << sep << ToString(attr); + sep = "\n"; + } + SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s", + os.str().c_str()); + } + + rknn_custom_string custom_string; + ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string, + sizeof(custom_string)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model"); + if (config_.debug) { + SHERPA_ONNX_LOGE("customs string: %s", custom_string.string); + } + auto meta = Parse(custom_string); + + if (config_.silero_vad.window_size != 512) { + SHERPA_ONNX_LOGE("we require window_size to be 512. Given: %d", + config_.silero_vad.window_size); + SHERPA_ONNX_EXIT(-1); + } + + if (config_.debug) { + for (const auto &p : meta) { + SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str()); + } + } + + if (meta.count("model_type") == 0) { + SHERPA_ONNX_LOGE("No model type found in '%s'", + config_.silero_vad.model.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + if (meta.at("model_type") != "silero-vad-v4") { + SHERPA_ONNX_LOGE("Expect model type silero-vad-v4 in '%s', given: '%s'", + config_.silero_vad.model.c_str(), + meta.at("model_type").c_str()); + SHERPA_ONNX_EXIT(-1); + } + + if (meta.count("sample_rate") == 0) { + SHERPA_ONNX_LOGE("No sample_rate found in '%s'", + config_.silero_vad.model.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + if (meta.at("sample_rate") != "16000") { + SHERPA_ONNX_LOGE("Expect sample rate 16000 in '%s', given: '%s'", + config_.silero_vad.model.c_str(), + meta.at("sample_rate").c_str()); + SHERPA_ONNX_EXIT(-1); + } + + if (meta.count("version") == 0) { + SHERPA_ONNX_LOGE("No version found in '%s'", + config_.silero_vad.model.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + if (meta.at("version") != "4") { + SHERPA_ONNX_LOGE("Expect version 4 in '%s', given: '%s'", + config_.silero_vad.model.c_str(), + meta.at("version").c_str()); + SHERPA_ONNX_EXIT(-1); + } + + if (meta.count("h_shape") == 0) { + SHERPA_ONNX_LOGE("No h_shape found in '%s'", + config_.silero_vad.model.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + if (meta.count("c_shape") == 0) { + SHERPA_ONNX_LOGE("No c_shape found in '%s'", + config_.silero_vad.model.c_str()); + SHERPA_ONNX_EXIT(-1); + } + + std::vector h_shape; + std::vector c_shape; + + SplitStringToIntegers(meta.at("h_shape"), ",", false, &h_shape); + SplitStringToIntegers(meta.at("c_shape"), ",", false, &c_shape); + if (h_shape.size() != 3 || c_shape.size() != 3) { + SHERPA_ONNX_LOGE("Incorrect shape for h (%d) or c (%d)", + static_cast(h_shape.size()), + static_cast(c_shape.size())); + SHERPA_ONNX_EXIT(-1); + } + + states_.resize(2); + states_[0].resize(h_shape[0] * h_shape[1] * h_shape[2]); + states_[1].resize(c_shape[0] * c_shape[1] * c_shape[2]); + + Reset(); + } + + float Run(const float *samples, int32_t n) { + std::vector inputs(input_attrs_.size()); + + for (int32_t i = 0; i < static_cast(inputs.size()); ++i) { + auto &input = inputs[i]; + auto &attr = input_attrs_[i]; + input.index = attr.index; + + if (attr.type == RKNN_TENSOR_FLOAT16) { + input.type = RKNN_TENSOR_FLOAT32; + } else if (attr.type == RKNN_TENSOR_INT64) { + input.type = RKNN_TENSOR_INT64; + } else { + SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type, + get_type_string(attr.type)); + SHERPA_ONNX_EXIT(-1); + } + + input.fmt = attr.fmt; + if (i == 0) { + input.buf = reinterpret_cast(const_cast(samples)); + input.size = n * sizeof(float); + } else { + input.buf = reinterpret_cast(states_[i - 1].data()); + input.size = states_[i - 1].size() * sizeof(float); + } + } + + std::vector out(output_attrs_[0].n_elems); + + auto &next_states = states_; + + std::vector outputs(output_attrs_.size()); + + for (int32_t i = 0; i < outputs.size(); ++i) { + auto &output = outputs[i]; + auto &attr = output_attrs_[i]; + output.index = attr.index; + output.is_prealloc = 1; + + if (attr.type == RKNN_TENSOR_FLOAT16) { + output.want_float = 1; + } else if (attr.type == RKNN_TENSOR_INT64) { + output.want_float = 0; + } else { + SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type, + get_type_string(attr.type)); + SHERPA_ONNX_EXIT(-1); + } + + if (i == 0) { + output.size = out.size() * sizeof(float); + output.buf = reinterpret_cast(out.data()); + } else { + output.size = next_states[i - 1].size() * sizeof(float); + output.buf = reinterpret_cast(next_states[i - 1].data()); + } + } + + auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data()); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs"); + + ret = rknn_run(ctx_, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model"); + + ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output"); + + return out[0]; + } + + private: + VadModelConfig config_; + rknn_context ctx_ = 0; + + std::vector input_attrs_; + std::vector output_attrs_; + + std::vector> states_; + + int64_t sample_rate_; + int32_t min_silence_samples_; + int32_t min_speech_samples_; + + bool triggered_ = false; + int32_t current_sample_ = 0; + int32_t temp_start_ = 0; + int32_t temp_end_ = 0; + + int32_t window_overlap_ = 0; +}; + +SileroVadModelRknn::SileroVadModelRknn(const VadModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +SileroVadModelRknn::SileroVadModelRknn(Manager *mgr, + const VadModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +SileroVadModelRknn::~SileroVadModelRknn() = default; + +void SileroVadModelRknn::Reset() { return impl_->Reset(); } + +bool SileroVadModelRknn::IsSpeech(const float *samples, int32_t n) { + return impl_->IsSpeech(samples, n); +} + +int32_t SileroVadModelRknn::WindowSize() const { return impl_->WindowSize(); } + +int32_t SileroVadModelRknn::WindowShift() const { return impl_->WindowShift(); } + +int32_t SileroVadModelRknn::MinSilenceDurationSamples() const { + return impl_->MinSilenceDurationSamples(); +} + +int32_t SileroVadModelRknn::MinSpeechDurationSamples() const { + return impl_->MinSpeechDurationSamples(); +} + +void SileroVadModelRknn::SetMinSilenceDuration(float s) { + impl_->SetMinSilenceDuration(s); +} + +void SileroVadModelRknn::SetThreshold(float threshold) { + impl_->SetThreshold(threshold); +} + +#if __ANDROID_API__ >= 9 +template SileroVadModelRknn::SileroVadModelRknn(AAssetManager *mgr, + const VadModelConfig &config); +#endif + +#if __OHOS__ +template SileroVadModelRknn::SileroVadModelRknn(NativeResourceManager *mgr, + const VadModelConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h b/sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h new file mode 100644 index 00000000..a11b34e6 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h @@ -0,0 +1,53 @@ +// sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_ +#define SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_ + +#include "rknn_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-model-config.h" +#include "sherpa-onnx/csrc/vad-model.h" + +namespace sherpa_onnx { + +class SileroVadModelRknn : public VadModel { + public: + explicit SileroVadModelRknn(const VadModelConfig &config); + + template + SileroVadModelRknn(Manager *mgr, const VadModelConfig &config); + + ~SileroVadModelRknn() override; + + // reset the internal model states + void Reset() override; + + /** + * @param samples Pointer to a 1-d array containing audio samples. + * Each sample should be normalized to the range [-1, 1]. + * @param n Number of samples. + * + * @return Return true if speech is detected. Return false otherwise. + */ + bool IsSpeech(const float *samples, int32_t n) override; + + // For silero vad V4, it is WindowShift(). + int32_t WindowSize() const override; + + // 512 + int32_t WindowShift() const override; + + int32_t MinSilenceDurationSamples() const override; + int32_t MinSpeechDurationSamples() const override; + + void SetMinSilenceDuration(float s) override; + void SetThreshold(float threshold) override; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-vad-alsa-offline-asr.cc b/sherpa-onnx/csrc/sherpa-onnx-vad-alsa-offline-asr.cc index 0ee6da63..0c3a17b4 100644 --- a/sherpa-onnx/csrc/sherpa-onnx-vad-alsa-offline-asr.cc +++ b/sherpa-onnx/csrc/sherpa-onnx-vad-alsa-offline-asr.cc @@ -129,15 +129,13 @@ as the device_name. exit(-1); } - int32_t chunk = 0.1 * alsa.GetActualSampleRate(); - fprintf(stderr, "Started. Please speak\n"); int32_t window_size = vad_config.silero_vad.window_size; int32_t index = 0; while (!stop) { - const std::vector &samples = alsa.Read(chunk); + const std::vector &samples = alsa.Read(window_size); vad->AcceptWaveform(samples.data(), samples.size()); while (!vad->Empty()) { diff --git a/sherpa-onnx/csrc/vad-model-config.cc b/sherpa-onnx/csrc/vad-model-config.cc index f02ad01c..a250bf7e 100644 --- a/sherpa-onnx/csrc/vad-model-config.cc +++ b/sherpa-onnx/csrc/vad-model-config.cc @@ -7,6 +7,9 @@ #include #include +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/text-utils.h" + namespace sherpa_onnx { void VadModelConfig::Register(ParseOptions *po) { @@ -26,7 +29,27 @@ void VadModelConfig::Register(ParseOptions *po) { "true to display debug information when loading vad models"); } -bool VadModelConfig::Validate() const { return silero_vad.Validate(); } +bool VadModelConfig::Validate() const { + if (provider != "rknn") { + if (!silero_vad.model.empty() && EndsWith(silero_vad.model, ".rknn")) { + SHERPA_ONNX_LOGE( + "--provider is %s, which is not rknn, but you pass an rknn model " + "'%s'", + provider.c_str(), silero_vad.model.c_str()); + return false; + } + } + + if (provider == "rknn") { + if (!silero_vad.model.empty() && EndsWith(silero_vad.model, ".onnx")) { + SHERPA_ONNX_LOGE("--provider is rknn, but you pass an onnx model '%s'", + silero_vad.model.c_str()); + return false; + } + } + + return silero_vad.Validate(); +} std::string VadModelConfig::ToString() const { std::ostringstream os; diff --git a/sherpa-onnx/csrc/vad-model.cc b/sherpa-onnx/csrc/vad-model.cc index 58203bb9..aa6b61a8 100644 --- a/sherpa-onnx/csrc/vad-model.cc +++ b/sherpa-onnx/csrc/vad-model.cc @@ -13,19 +13,27 @@ #include "rawfile/raw_file_manager.h" #endif +#if SHERPA_ONNX_ENABLE_RKNN +#include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h" +#endif + #include "sherpa-onnx/csrc/silero-vad-model.h" namespace sherpa_onnx { std::unique_ptr VadModel::Create(const VadModelConfig &config) { - // TODO(fangjun): Support other VAD models. + if (config.provider == "rknn") { + return std::make_unique(config); + } return std::make_unique(config); } template std::unique_ptr VadModel::Create(Manager *mgr, const VadModelConfig &config) { - // TODO(fangjun): Support other VAD models. + if (config.provider == "rknn") { + return std::make_unique(mgr, config); + } return std::make_unique(mgr, config); }