Add C++ runtime for silero_vad with RKNN (#2078)
This commit is contained in:
@@ -100,12 +100,11 @@ int32_t main() {
|
|||||||
|
|
||||||
while (!is_eof) {
|
while (!is_eof) {
|
||||||
if (i + window_size < wave->num_samples) {
|
if (i + window_size < wave->num_samples) {
|
||||||
SherpaOnnxVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i,
|
SherpaOnnxVoiceActivityDetectorAcceptWaveform(vad, wave->samples + i,
|
||||||
window_size);
|
window_size);
|
||||||
}
|
} else {
|
||||||
else {
|
SherpaOnnxVoiceActivityDetectorFlush(vad);
|
||||||
SherpaOnnxVoiceActivityDetectorFlush(vad);
|
is_eof = 1;
|
||||||
is_eof = 1;
|
|
||||||
}
|
}
|
||||||
while (!SherpaOnnxVoiceActivityDetectorEmpty(vad)) {
|
while (!SherpaOnnxVoiceActivityDetectorEmpty(vad)) {
|
||||||
const SherpaOnnxSpeechSegment *segment =
|
const SherpaOnnxSpeechSegment *segment =
|
||||||
|
|||||||
75
scripts/gtcrn/show.py
Executable file
75
scripts/gtcrn/show.py
Executable file
@@ -0,0 +1,75 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
import onnxruntime
|
||||||
|
import onnx
|
||||||
|
|
||||||
|
"""
|
||||||
|
[key: "model_type"
|
||||||
|
value: "gtcrn"
|
||||||
|
, key: "comment"
|
||||||
|
value: "gtcrn_simple"
|
||||||
|
, key: "version"
|
||||||
|
value: "1"
|
||||||
|
, key: "sample_rate"
|
||||||
|
value: "16000"
|
||||||
|
, key: "model_url"
|
||||||
|
value: "https://github.com/Xiaobin-Rong/gtcrn/blob/main/stream/onnx_models/gtcrn_simple.onnx"
|
||||||
|
, key: "maintainer"
|
||||||
|
value: "k2-fsa"
|
||||||
|
, key: "comment2"
|
||||||
|
value: "Please see also https://github.com/Xiaobin-Rong/gtcrn"
|
||||||
|
, key: "conv_cache_shape"
|
||||||
|
value: "2,1,16,16,33"
|
||||||
|
, key: "tra_cache_shape"
|
||||||
|
value: "2,3,1,1,16"
|
||||||
|
, key: "inter_cache_shape"
|
||||||
|
value: "2,1,33,16"
|
||||||
|
, key: "n_fft"
|
||||||
|
value: "512"
|
||||||
|
, key: "hop_length"
|
||||||
|
value: "256"
|
||||||
|
, key: "window_length"
|
||||||
|
value: "512"
|
||||||
|
, key: "window_type"
|
||||||
|
value: "hann_sqrt"
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
NodeArg(name='mix', type='tensor(float)', shape=[1, 257, 1, 2])
|
||||||
|
NodeArg(name='conv_cache', type='tensor(float)', shape=[2, 1, 16, 16, 33])
|
||||||
|
NodeArg(name='tra_cache', type='tensor(float)', shape=[2, 3, 1, 1, 16])
|
||||||
|
NodeArg(name='inter_cache', type='tensor(float)', shape=[2, 1, 33, 16])
|
||||||
|
-----
|
||||||
|
NodeArg(name='enh', type='tensor(float)', shape=[1, 257, 1, 2])
|
||||||
|
NodeArg(name='conv_cache_out', type='tensor(float)', shape=[2, 1, 16, 16, 33])
|
||||||
|
NodeArg(name='tra_cache_out', type='tensor(float)', shape=[2, 3, 1, 1, 16])
|
||||||
|
NodeArg(name='inter_cache_out', type='tensor(float)', shape=[2, 1, 33, 16])
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def show(filename):
|
||||||
|
model = onnx.load(filename)
|
||||||
|
print(model.metadata_props)
|
||||||
|
|
||||||
|
session_opts = onnxruntime.SessionOptions()
|
||||||
|
session_opts.log_severity_level = 3
|
||||||
|
sess = onnxruntime.InferenceSession(
|
||||||
|
filename, session_opts, providers=["CPUExecutionProvider"]
|
||||||
|
)
|
||||||
|
for i in sess.get_inputs():
|
||||||
|
print(i)
|
||||||
|
|
||||||
|
print("-----")
|
||||||
|
|
||||||
|
for i in sess.get_outputs():
|
||||||
|
print(i)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
show("./gtcrn_simple.onnx")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -5,15 +5,94 @@ import onnx
|
|||||||
import torch
|
import torch
|
||||||
from onnxsim import simplify
|
from onnxsim import simplify
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def simple_pad(x: Tensor, pad: int) -> Tensor:
|
||||||
|
# _0 = torch.slice(torch.slice(torch.slice(x), 1), 2, 1, torch.add(1, pad))
|
||||||
|
_0 = x[:, :, 1 : 1 + pad]
|
||||||
|
|
||||||
|
left_pad = torch.flip(_0, [-1])
|
||||||
|
# _1 = torch.slice(torch.slice(torch.slice(x), 1), 2, torch.sub(-1, pad), -1)
|
||||||
|
|
||||||
|
_1 = x[:, :, (-1 - pad) : -1]
|
||||||
|
|
||||||
|
right_pad = torch.flip(_1, [-1])
|
||||||
|
_2 = torch.cat([left_pad, x, right_pad], 2)
|
||||||
|
return _2
|
||||||
|
|
||||||
|
|
||||||
|
class MyModule(torch.nn.Module):
|
||||||
|
def __init__(self, m):
|
||||||
|
super().__init__()
|
||||||
|
self.m = m
|
||||||
|
|
||||||
|
def adaptive_normalization_forward(self, spect):
|
||||||
|
m = self.m._model.adaptive_normalization
|
||||||
|
_0 = simple_pad
|
||||||
|
|
||||||
|
# Note(fangjun): rknn uses fp16 by default, whose max value is 65504
|
||||||
|
# so we need to re-write the computation for spect0
|
||||||
|
# spect0 = torch.log1p(torch.mul(spect, 1048576))
|
||||||
|
spect0 = torch.log1p(spect) + 13.86294
|
||||||
|
|
||||||
|
_1 = torch.eq(len(spect0.shape), 2)
|
||||||
|
if _1:
|
||||||
|
_2 = torch.unsqueeze(spect0, 0)
|
||||||
|
spect1 = _2
|
||||||
|
else:
|
||||||
|
spect1 = spect0
|
||||||
|
mean = torch.mean(spect1, [1], True)
|
||||||
|
to_pad = m.to_pad
|
||||||
|
mean0 = _0(
|
||||||
|
mean,
|
||||||
|
to_pad,
|
||||||
|
)
|
||||||
|
filter_ = m.filter_
|
||||||
|
mean1 = torch.conv1d(mean0, filter_)
|
||||||
|
mean_mean = torch.mean(mean1, [-1], True)
|
||||||
|
spect2 = torch.add(spect1, torch.neg(mean_mean))
|
||||||
|
return spect2
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, h: torch.Tensor, c: torch.Tensor):
|
||||||
|
m = self.m._model
|
||||||
|
|
||||||
|
feature_extractor = m.feature_extractor
|
||||||
|
x0 = (feature_extractor).forward(
|
||||||
|
x,
|
||||||
|
)
|
||||||
|
norm = self.adaptive_normalization_forward(x0)
|
||||||
|
x1 = torch.cat([x0, norm], 1)
|
||||||
|
first_layer = m.first_layer
|
||||||
|
x2 = (first_layer).forward(
|
||||||
|
x1,
|
||||||
|
)
|
||||||
|
encoder = m.encoder
|
||||||
|
x3 = (encoder).forward(
|
||||||
|
x2,
|
||||||
|
)
|
||||||
|
decoder = m.decoder
|
||||||
|
x4, h0, c0, = (decoder).forward(
|
||||||
|
x3,
|
||||||
|
h,
|
||||||
|
c,
|
||||||
|
)
|
||||||
|
_0 = torch.mean(torch.squeeze(x4, 1), [1])
|
||||||
|
out = torch.unsqueeze(_0, 1)
|
||||||
|
return (out, h0, c0)
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
m = torch.jit.load("./silero_vad.jit")
|
m = torch.jit.load("./silero_vad.jit")
|
||||||
|
m = MyModule(m)
|
||||||
x = torch.rand((1, 512), dtype=torch.float32)
|
x = torch.rand((1, 512), dtype=torch.float32)
|
||||||
h = torch.rand((2, 1, 64), dtype=torch.float32)
|
h = torch.rand((2, 1, 64), dtype=torch.float32)
|
||||||
c = torch.rand((2, 1, 64), dtype=torch.float32)
|
c = torch.rand((2, 1, 64), dtype=torch.float32)
|
||||||
|
m = torch.jit.script(m)
|
||||||
torch.onnx.export(
|
torch.onnx.export(
|
||||||
m._model,
|
m,
|
||||||
(x, h, c),
|
(x, h, c),
|
||||||
"m.onnx",
|
"m.onnx",
|
||||||
input_names=["x", "h", "c"],
|
input_names=["x", "h", "c"],
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
import onnxruntime
|
import onnxruntime
|
||||||
import onnx
|
import onnx
|
||||||
|
|||||||
141
scripts/silero_vad/v4/test-on-rk3588-board.py
Executable file
141
scripts/silero_vad/v4/test-on-rk3588-board.py
Executable file
@@ -0,0 +1,141 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
# Please run this file on your rk3588 board
|
||||||
|
|
||||||
|
try:
|
||||||
|
from rknnlite.api import RKNNLite
|
||||||
|
except:
|
||||||
|
print("Please run this file on your board (linux + aarch64 + npu)")
|
||||||
|
print("You need to install rknn_toolkit_lite2")
|
||||||
|
print(
|
||||||
|
" from https://github.com/airockchip/rknn-toolkit2/tree/master/rknn-toolkit-lite2/packages"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
"https://github.com/airockchip/rknn-toolkit2/blob/v2.1.0/rknn-toolkit-lite2/packages/rknn_toolkit_lite2-2.1.0-cp310-cp310-linux_aarch64.whl"
|
||||||
|
)
|
||||||
|
print("is known to work")
|
||||||
|
raise
|
||||||
|
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
|
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
|
||||||
|
data, sample_rate = sf.read(
|
||||||
|
filename,
|
||||||
|
always_2d=True,
|
||||||
|
dtype="float32",
|
||||||
|
)
|
||||||
|
data = data[:, 0] # use only the first channel
|
||||||
|
|
||||||
|
samples = np.ascontiguousarray(data)
|
||||||
|
return samples, sample_rate
|
||||||
|
|
||||||
|
|
||||||
|
def init_model(filename, target_platform="rk3588"):
|
||||||
|
if not Path(filename).is_file():
|
||||||
|
exit(f"{filename} does not exist")
|
||||||
|
|
||||||
|
rknn_lite = RKNNLite(verbose=False)
|
||||||
|
ret = rknn_lite.load_rknn(path=filename)
|
||||||
|
if ret != 0:
|
||||||
|
exit(f"Load model {filename} failed!")
|
||||||
|
|
||||||
|
ret = rknn_lite.init_runtime(core_mask=RKNNLite.NPU_CORE_0)
|
||||||
|
if ret != 0:
|
||||||
|
exit(f"Failed to init rknn runtime for {filename}")
|
||||||
|
return rknn_lite
|
||||||
|
|
||||||
|
|
||||||
|
class RKNNModel:
|
||||||
|
def __init__(self, model: str, target_platform="rk3588"):
|
||||||
|
self.model = init_model(model)
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
self.model.release()
|
||||||
|
|
||||||
|
def __call__(self, x: np.ndarray, h: np.ndarray, c: np.ndarray):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (1, 512), np.float32
|
||||||
|
h: (2, 1, 64), np.float32
|
||||||
|
c: (2, 1, 64), np.float32
|
||||||
|
Returns:
|
||||||
|
prob:
|
||||||
|
next_h:
|
||||||
|
next_c
|
||||||
|
"""
|
||||||
|
out, next_h, next_c = self.model.inference(inputs=[x, h, c])
|
||||||
|
return out.item(), next_h, next_c
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
model = RKNNModel(model="./m.rknn")
|
||||||
|
for i in range(1):
|
||||||
|
test(model)
|
||||||
|
|
||||||
|
|
||||||
|
def test(model):
|
||||||
|
print("started")
|
||||||
|
start = time.time()
|
||||||
|
samples, sample_rate = load_audio("./lei-jun-test.wav")
|
||||||
|
assert sample_rate == 16000, sample_rate
|
||||||
|
|
||||||
|
window_size = 512
|
||||||
|
|
||||||
|
h = np.zeros((2, 1, 64), dtype=np.float32)
|
||||||
|
c = np.zeros((2, 1, 64), dtype=np.float32)
|
||||||
|
|
||||||
|
threshold = 0.5
|
||||||
|
num_windows = samples.shape[0] // window_size
|
||||||
|
out = []
|
||||||
|
for i in range(num_windows):
|
||||||
|
print(i, num_windows)
|
||||||
|
this_samples = samples[i * window_size : (i + 1) * window_size]
|
||||||
|
prob, h, c = model(this_samples[None], h, c)
|
||||||
|
out.append(prob > threshold)
|
||||||
|
|
||||||
|
min_speech_duration = 0.25 * sample_rate / window_size
|
||||||
|
min_silence_duration = 0.25 * sample_rate / window_size
|
||||||
|
|
||||||
|
result = []
|
||||||
|
last = -1
|
||||||
|
for k, f in enumerate(out):
|
||||||
|
if f >= threshold:
|
||||||
|
if last == -1:
|
||||||
|
last = k
|
||||||
|
elif last != -1:
|
||||||
|
if k - last > min_speech_duration:
|
||||||
|
result.append((last, k))
|
||||||
|
last = -1
|
||||||
|
|
||||||
|
if last != -1 and k - last > min_speech_duration:
|
||||||
|
result.append((last, k))
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
print("Empty for ./lei-jun-test.wav")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
final = [result[0]]
|
||||||
|
for r in result[1:]:
|
||||||
|
f = final[-1]
|
||||||
|
if r[0] - f[1] < min_silence_duration:
|
||||||
|
final[-1] = (f[0], r[1])
|
||||||
|
else:
|
||||||
|
final.append(r)
|
||||||
|
|
||||||
|
for f in final:
|
||||||
|
start = f[0] * window_size / sample_rate
|
||||||
|
end = f[1] * window_size / sample_rate
|
||||||
|
print("{:.3f} -- {:.3f}".format(start, end))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -97,10 +97,13 @@ def main():
|
|||||||
h, c = model.get_init_states()
|
h, c = model.get_init_states()
|
||||||
window_size = 512
|
window_size = 512
|
||||||
num_windows = samples.shape[0] // window_size
|
num_windows = samples.shape[0] // window_size
|
||||||
|
|
||||||
for i in range(num_windows):
|
for i in range(num_windows):
|
||||||
start = i * window_size
|
start = i * window_size
|
||||||
end = start + window_size
|
end = start + window_size
|
||||||
|
|
||||||
p, h, c = model(samples[start:end], h, c)
|
p, h, c = model(samples[start:end], h, c)
|
||||||
|
|
||||||
probs.append(p[0].item())
|
probs.append(p[0].item())
|
||||||
|
|
||||||
threshold = 0.5
|
threshold = 0.5
|
||||||
|
|||||||
@@ -159,6 +159,7 @@ if(SHERPA_ONNX_ENABLE_RKNN)
|
|||||||
./rknn/online-transducer-modified-beam-search-decoder-rknn.cc
|
./rknn/online-transducer-modified-beam-search-decoder-rknn.cc
|
||||||
./rknn/online-zipformer-ctc-model-rknn.cc
|
./rknn/online-zipformer-ctc-model-rknn.cc
|
||||||
./rknn/online-zipformer-transducer-model-rknn.cc
|
./rknn/online-zipformer-transducer-model-rknn.cc
|
||||||
|
./rknn/silero-vad-model-rknn.cc
|
||||||
./rknn/utils.cc
|
./rknn/utils.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -468,6 +469,7 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
|
|||||||
microphone.cc
|
microphone.cc
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
add_executable(sherpa-onnx-microphone-offline
|
add_executable(sherpa-onnx-microphone-offline
|
||||||
sherpa-onnx-microphone-offline.cc
|
sherpa-onnx-microphone-offline.cc
|
||||||
microphone.cc
|
microphone.cc
|
||||||
@@ -498,11 +500,11 @@ if(SHERPA_ONNX_ENABLE_PORTAUDIO AND SHERPA_ONNX_ENABLE_BINARY)
|
|||||||
)
|
)
|
||||||
|
|
||||||
set(exes
|
set(exes
|
||||||
sherpa-onnx-microphone
|
|
||||||
sherpa-onnx-keyword-spotter-microphone
|
sherpa-onnx-keyword-spotter-microphone
|
||||||
|
sherpa-onnx-microphone
|
||||||
sherpa-onnx-microphone-offline
|
sherpa-onnx-microphone-offline
|
||||||
sherpa-onnx-microphone-offline-speaker-identification
|
|
||||||
sherpa-onnx-microphone-offline-audio-tagging
|
sherpa-onnx-microphone-offline-audio-tagging
|
||||||
|
sherpa-onnx-microphone-offline-speaker-identification
|
||||||
sherpa-onnx-vad-microphone
|
sherpa-onnx-vad-microphone
|
||||||
sherpa-onnx-vad-microphone-offline-asr
|
sherpa-onnx-vad-microphone-offline-asr
|
||||||
sherpa-onnx-vad-with-offline-asr
|
sherpa-onnx-vad-with-offline-asr
|
||||||
|
|||||||
470
sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc
Normal file
470
sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc
Normal file
@@ -0,0 +1,470 @@
|
|||||||
|
// sherpa-onnx/csrc/rknn/silero-vad-model-rknn.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
#include "android/asset_manager.h"
|
||||||
|
#include "android/asset_manager_jni.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if __OHOS__
|
||||||
|
#include "rawfile/raw_file_manager.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/file-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/rknn/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/rknn/utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class SileroVadModelRknn::Impl {
|
||||||
|
public:
|
||||||
|
~Impl() {
|
||||||
|
auto ret = rknn_destroy(ctx_);
|
||||||
|
if (ret != RKNN_SUCC) {
|
||||||
|
SHERPA_ONNX_LOGE("Failed to destroy the context");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit Impl(const VadModelConfig &config)
|
||||||
|
: config_(config), sample_rate_(config.sample_rate) {
|
||||||
|
auto buf = ReadFile(config.silero_vad.model);
|
||||||
|
Init(buf.data(), buf.size());
|
||||||
|
|
||||||
|
if (sample_rate_ != 16000) {
|
||||||
|
SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
|
||||||
|
config.sample_rate);
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
min_silence_samples_ =
|
||||||
|
sample_rate_ * config_.silero_vad.min_silence_duration;
|
||||||
|
|
||||||
|
min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename Manager>
|
||||||
|
Impl(Manager *mgr, const VadModelConfig &config)
|
||||||
|
: config_(config), sample_rate_(config.sample_rate) {
|
||||||
|
auto buf = ReadFile(mgr, config.silero_vad.model);
|
||||||
|
Init(buf.data(), buf.size());
|
||||||
|
|
||||||
|
if (sample_rate_ != 16000) {
|
||||||
|
SHERPA_ONNX_LOGE("Expected sample rate 16000. Given: %d",
|
||||||
|
config.sample_rate);
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
min_silence_samples_ =
|
||||||
|
sample_rate_ * config_.silero_vad.min_silence_duration;
|
||||||
|
|
||||||
|
min_speech_samples_ = sample_rate_ * config_.silero_vad.min_speech_duration;
|
||||||
|
}
|
||||||
|
|
||||||
|
void Reset() {
|
||||||
|
for (auto &s : states_) {
|
||||||
|
std::fill(s.begin(), s.end(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
triggered_ = false;
|
||||||
|
current_sample_ = 0;
|
||||||
|
temp_start_ = 0;
|
||||||
|
temp_end_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsSpeech(const float *samples, int32_t n) {
|
||||||
|
if (n != WindowSize()) {
|
||||||
|
SHERPA_ONNX_LOGE("n: %d != window_size: %d", n, WindowSize());
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
float prob = Run(samples, n);
|
||||||
|
|
||||||
|
float threshold = config_.silero_vad.threshold;
|
||||||
|
|
||||||
|
current_sample_ += config_.silero_vad.window_size;
|
||||||
|
|
||||||
|
if (prob > threshold && temp_end_ != 0) {
|
||||||
|
temp_end_ = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (prob > threshold && temp_start_ == 0) {
|
||||||
|
// start speaking, but we require that it must satisfy
|
||||||
|
// min_speech_duration
|
||||||
|
temp_start_ = current_sample_;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (prob > threshold && temp_start_ != 0 && !triggered_) {
|
||||||
|
if (current_sample_ - temp_start_ < min_speech_samples_) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
triggered_ = true;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((prob < threshold) && !triggered_) {
|
||||||
|
// silence
|
||||||
|
temp_start_ = 0;
|
||||||
|
temp_end_ = 0;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((prob > threshold - 0.15) && triggered_) {
|
||||||
|
// speaking
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((prob > threshold) && !triggered_) {
|
||||||
|
// start speaking
|
||||||
|
triggered_ = true;
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((prob < threshold) && triggered_) {
|
||||||
|
// stop to speak
|
||||||
|
if (temp_end_ == 0) {
|
||||||
|
temp_end_ = current_sample_;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (current_sample_ - temp_end_ < min_silence_samples_) {
|
||||||
|
// continue speaking
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// stopped speaking
|
||||||
|
temp_start_ = 0;
|
||||||
|
temp_end_ = 0;
|
||||||
|
triggered_ = false;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t WindowShift() const { return config_.silero_vad.window_size; }
|
||||||
|
|
||||||
|
int32_t WindowSize() const {
|
||||||
|
return config_.silero_vad.window_size + window_overlap_;
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t MinSilenceDurationSamples() const { return min_silence_samples_; }
|
||||||
|
|
||||||
|
int32_t MinSpeechDurationSamples() const { return min_speech_samples_; }
|
||||||
|
|
||||||
|
void SetMinSilenceDuration(float s) {
|
||||||
|
min_silence_samples_ = sample_rate_ * s;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetThreshold(float threshold) {
|
||||||
|
config_.silero_vad.threshold = threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void Init(void *model_data, size_t model_data_length) {
|
||||||
|
auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr);
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init silero vad model '%s'",
|
||||||
|
config_.silero_vad.model.c_str());
|
||||||
|
|
||||||
|
if (config_.debug) {
|
||||||
|
rknn_sdk_version v;
|
||||||
|
ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v));
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version");
|
||||||
|
|
||||||
|
SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version,
|
||||||
|
v.drv_version);
|
||||||
|
}
|
||||||
|
|
||||||
|
rknn_input_output_num io_num;
|
||||||
|
ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num));
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model");
|
||||||
|
|
||||||
|
if (config_.debug) {
|
||||||
|
SHERPA_ONNX_LOGE("model: %d inputs, %d outputs",
|
||||||
|
static_cast<int32_t>(io_num.n_input),
|
||||||
|
static_cast<int32_t>(io_num.n_output));
|
||||||
|
}
|
||||||
|
|
||||||
|
input_attrs_.resize(io_num.n_input);
|
||||||
|
output_attrs_.resize(io_num.n_output);
|
||||||
|
|
||||||
|
int32_t i = 0;
|
||||||
|
for (auto &attr : input_attrs_) {
|
||||||
|
memset(&attr, 0, sizeof(attr));
|
||||||
|
attr.index = i;
|
||||||
|
ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr));
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
std::string sep;
|
||||||
|
for (auto &attr : input_attrs_) {
|
||||||
|
os << sep << ToString(attr);
|
||||||
|
sep = "\n";
|
||||||
|
}
|
||||||
|
SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s",
|
||||||
|
os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
i = 0;
|
||||||
|
for (auto &attr : output_attrs_) {
|
||||||
|
memset(&attr, 0, sizeof(attr));
|
||||||
|
attr.index = i;
|
||||||
|
ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr));
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i);
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config_.debug) {
|
||||||
|
std::ostringstream os;
|
||||||
|
std::string sep;
|
||||||
|
for (auto &attr : output_attrs_) {
|
||||||
|
os << sep << ToString(attr);
|
||||||
|
sep = "\n";
|
||||||
|
}
|
||||||
|
SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s",
|
||||||
|
os.str().c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
rknn_custom_string custom_string;
|
||||||
|
ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string,
|
||||||
|
sizeof(custom_string));
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model");
|
||||||
|
if (config_.debug) {
|
||||||
|
SHERPA_ONNX_LOGE("customs string: %s", custom_string.string);
|
||||||
|
}
|
||||||
|
auto meta = Parse(custom_string);
|
||||||
|
|
||||||
|
if (config_.silero_vad.window_size != 512) {
|
||||||
|
SHERPA_ONNX_LOGE("we require window_size to be 512. Given: %d",
|
||||||
|
config_.silero_vad.window_size);
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config_.debug) {
|
||||||
|
for (const auto &p : meta) {
|
||||||
|
SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta.count("model_type") == 0) {
|
||||||
|
SHERPA_ONNX_LOGE("No model type found in '%s'",
|
||||||
|
config_.silero_vad.model.c_str());
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta.at("model_type") != "silero-vad-v4") {
|
||||||
|
SHERPA_ONNX_LOGE("Expect model type silero-vad-v4 in '%s', given: '%s'",
|
||||||
|
config_.silero_vad.model.c_str(),
|
||||||
|
meta.at("model_type").c_str());
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta.count("sample_rate") == 0) {
|
||||||
|
SHERPA_ONNX_LOGE("No sample_rate found in '%s'",
|
||||||
|
config_.silero_vad.model.c_str());
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta.at("sample_rate") != "16000") {
|
||||||
|
SHERPA_ONNX_LOGE("Expect sample rate 16000 in '%s', given: '%s'",
|
||||||
|
config_.silero_vad.model.c_str(),
|
||||||
|
meta.at("sample_rate").c_str());
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta.count("version") == 0) {
|
||||||
|
SHERPA_ONNX_LOGE("No version found in '%s'",
|
||||||
|
config_.silero_vad.model.c_str());
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta.at("version") != "4") {
|
||||||
|
SHERPA_ONNX_LOGE("Expect version 4 in '%s', given: '%s'",
|
||||||
|
config_.silero_vad.model.c_str(),
|
||||||
|
meta.at("version").c_str());
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta.count("h_shape") == 0) {
|
||||||
|
SHERPA_ONNX_LOGE("No h_shape found in '%s'",
|
||||||
|
config_.silero_vad.model.c_str());
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (meta.count("c_shape") == 0) {
|
||||||
|
SHERPA_ONNX_LOGE("No c_shape found in '%s'",
|
||||||
|
config_.silero_vad.model.c_str());
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> h_shape;
|
||||||
|
std::vector<int64_t> c_shape;
|
||||||
|
|
||||||
|
SplitStringToIntegers(meta.at("h_shape"), ",", false, &h_shape);
|
||||||
|
SplitStringToIntegers(meta.at("c_shape"), ",", false, &c_shape);
|
||||||
|
if (h_shape.size() != 3 || c_shape.size() != 3) {
|
||||||
|
SHERPA_ONNX_LOGE("Incorrect shape for h (%d) or c (%d)",
|
||||||
|
static_cast<int32_t>(h_shape.size()),
|
||||||
|
static_cast<int32_t>(c_shape.size()));
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
states_.resize(2);
|
||||||
|
states_[0].resize(h_shape[0] * h_shape[1] * h_shape[2]);
|
||||||
|
states_[1].resize(c_shape[0] * c_shape[1] * c_shape[2]);
|
||||||
|
|
||||||
|
Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
float Run(const float *samples, int32_t n) {
|
||||||
|
std::vector<rknn_input> inputs(input_attrs_.size());
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < static_cast<int32_t>(inputs.size()); ++i) {
|
||||||
|
auto &input = inputs[i];
|
||||||
|
auto &attr = input_attrs_[i];
|
||||||
|
input.index = attr.index;
|
||||||
|
|
||||||
|
if (attr.type == RKNN_TENSOR_FLOAT16) {
|
||||||
|
input.type = RKNN_TENSOR_FLOAT32;
|
||||||
|
} else if (attr.type == RKNN_TENSOR_INT64) {
|
||||||
|
input.type = RKNN_TENSOR_INT64;
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
|
||||||
|
get_type_string(attr.type));
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
input.fmt = attr.fmt;
|
||||||
|
if (i == 0) {
|
||||||
|
input.buf = reinterpret_cast<void *>(const_cast<float *>(samples));
|
||||||
|
input.size = n * sizeof(float);
|
||||||
|
} else {
|
||||||
|
input.buf = reinterpret_cast<void *>(states_[i - 1].data());
|
||||||
|
input.size = states_[i - 1].size() * sizeof(float);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> out(output_attrs_[0].n_elems);
|
||||||
|
|
||||||
|
auto &next_states = states_;
|
||||||
|
|
||||||
|
std::vector<rknn_output> outputs(output_attrs_.size());
|
||||||
|
|
||||||
|
for (int32_t i = 0; i < outputs.size(); ++i) {
|
||||||
|
auto &output = outputs[i];
|
||||||
|
auto &attr = output_attrs_[i];
|
||||||
|
output.index = attr.index;
|
||||||
|
output.is_prealloc = 1;
|
||||||
|
|
||||||
|
if (attr.type == RKNN_TENSOR_FLOAT16) {
|
||||||
|
output.want_float = 1;
|
||||||
|
} else if (attr.type == RKNN_TENSOR_INT64) {
|
||||||
|
output.want_float = 0;
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type,
|
||||||
|
get_type_string(attr.type));
|
||||||
|
SHERPA_ONNX_EXIT(-1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i == 0) {
|
||||||
|
output.size = out.size() * sizeof(float);
|
||||||
|
output.buf = reinterpret_cast<void *>(out.data());
|
||||||
|
} else {
|
||||||
|
output.size = next_states[i - 1].size() * sizeof(float);
|
||||||
|
output.buf = reinterpret_cast<void *>(next_states[i - 1].data());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data());
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs");
|
||||||
|
|
||||||
|
ret = rknn_run(ctx_, nullptr);
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model");
|
||||||
|
|
||||||
|
ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr);
|
||||||
|
SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output");
|
||||||
|
|
||||||
|
return out[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
VadModelConfig config_;
|
||||||
|
rknn_context ctx_ = 0;
|
||||||
|
|
||||||
|
std::vector<rknn_tensor_attr> input_attrs_;
|
||||||
|
std::vector<rknn_tensor_attr> output_attrs_;
|
||||||
|
|
||||||
|
std::vector<std::vector<float>> states_;
|
||||||
|
|
||||||
|
int64_t sample_rate_;
|
||||||
|
int32_t min_silence_samples_;
|
||||||
|
int32_t min_speech_samples_;
|
||||||
|
|
||||||
|
bool triggered_ = false;
|
||||||
|
int32_t current_sample_ = 0;
|
||||||
|
int32_t temp_start_ = 0;
|
||||||
|
int32_t temp_end_ = 0;
|
||||||
|
|
||||||
|
int32_t window_overlap_ = 0;
|
||||||
|
};
|
||||||
|
|
||||||
|
SileroVadModelRknn::SileroVadModelRknn(const VadModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
template <typename Manager>
|
||||||
|
SileroVadModelRknn::SileroVadModelRknn(Manager *mgr,
|
||||||
|
const VadModelConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||||
|
|
||||||
|
SileroVadModelRknn::~SileroVadModelRknn() = default;
|
||||||
|
|
||||||
|
void SileroVadModelRknn::Reset() { return impl_->Reset(); }
|
||||||
|
|
||||||
|
bool SileroVadModelRknn::IsSpeech(const float *samples, int32_t n) {
|
||||||
|
return impl_->IsSpeech(samples, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t SileroVadModelRknn::WindowSize() const { return impl_->WindowSize(); }
|
||||||
|
|
||||||
|
int32_t SileroVadModelRknn::WindowShift() const { return impl_->WindowShift(); }
|
||||||
|
|
||||||
|
int32_t SileroVadModelRknn::MinSilenceDurationSamples() const {
|
||||||
|
return impl_->MinSilenceDurationSamples();
|
||||||
|
}
|
||||||
|
|
||||||
|
int32_t SileroVadModelRknn::MinSpeechDurationSamples() const {
|
||||||
|
return impl_->MinSpeechDurationSamples();
|
||||||
|
}
|
||||||
|
|
||||||
|
void SileroVadModelRknn::SetMinSilenceDuration(float s) {
|
||||||
|
impl_->SetMinSilenceDuration(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
void SileroVadModelRknn::SetThreshold(float threshold) {
|
||||||
|
impl_->SetThreshold(threshold);
|
||||||
|
}
|
||||||
|
|
||||||
|
#if __ANDROID_API__ >= 9
|
||||||
|
template SileroVadModelRknn::SileroVadModelRknn(AAssetManager *mgr,
|
||||||
|
const VadModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if __OHOS__
|
||||||
|
template SileroVadModelRknn::SileroVadModelRknn(NativeResourceManager *mgr,
|
||||||
|
const VadModelConfig &config);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
53
sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h
Normal file
53
sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
// sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2025 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_
|
||||||
|
|
||||||
|
#include "rknn_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||||
|
#include "sherpa-onnx/csrc/vad-model.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class SileroVadModelRknn : public VadModel {
|
||||||
|
public:
|
||||||
|
explicit SileroVadModelRknn(const VadModelConfig &config);
|
||||||
|
|
||||||
|
template <typename Manager>
|
||||||
|
SileroVadModelRknn(Manager *mgr, const VadModelConfig &config);
|
||||||
|
|
||||||
|
~SileroVadModelRknn() override;
|
||||||
|
|
||||||
|
// reset the internal model states
|
||||||
|
void Reset() override;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param samples Pointer to a 1-d array containing audio samples.
|
||||||
|
* Each sample should be normalized to the range [-1, 1].
|
||||||
|
* @param n Number of samples.
|
||||||
|
*
|
||||||
|
* @return Return true if speech is detected. Return false otherwise.
|
||||||
|
*/
|
||||||
|
bool IsSpeech(const float *samples, int32_t n) override;
|
||||||
|
|
||||||
|
// For silero vad V4, it is WindowShift().
|
||||||
|
int32_t WindowSize() const override;
|
||||||
|
|
||||||
|
// 512
|
||||||
|
int32_t WindowShift() const override;
|
||||||
|
|
||||||
|
int32_t MinSilenceDurationSamples() const override;
|
||||||
|
int32_t MinSpeechDurationSamples() const override;
|
||||||
|
|
||||||
|
void SetMinSilenceDuration(float s) override;
|
||||||
|
void SetThreshold(float threshold) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_RKNN_SILERO_VAD_MODEL_RKNN_H_
|
||||||
@@ -129,15 +129,13 @@ as the device_name.
|
|||||||
exit(-1);
|
exit(-1);
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t chunk = 0.1 * alsa.GetActualSampleRate();
|
|
||||||
|
|
||||||
fprintf(stderr, "Started. Please speak\n");
|
fprintf(stderr, "Started. Please speak\n");
|
||||||
|
|
||||||
int32_t window_size = vad_config.silero_vad.window_size;
|
int32_t window_size = vad_config.silero_vad.window_size;
|
||||||
int32_t index = 0;
|
int32_t index = 0;
|
||||||
|
|
||||||
while (!stop) {
|
while (!stop) {
|
||||||
const std::vector<float> &samples = alsa.Read(chunk);
|
const std::vector<float> &samples = alsa.Read(window_size);
|
||||||
vad->AcceptWaveform(samples.data(), samples.size());
|
vad->AcceptWaveform(samples.data(), samples.size());
|
||||||
|
|
||||||
while (!vad->Empty()) {
|
while (!vad->Empty()) {
|
||||||
|
|||||||
@@ -7,6 +7,9 @@
|
|||||||
#include <sstream>
|
#include <sstream>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
void VadModelConfig::Register(ParseOptions *po) {
|
void VadModelConfig::Register(ParseOptions *po) {
|
||||||
@@ -26,7 +29,27 @@ void VadModelConfig::Register(ParseOptions *po) {
|
|||||||
"true to display debug information when loading vad models");
|
"true to display debug information when loading vad models");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool VadModelConfig::Validate() const { return silero_vad.Validate(); }
|
bool VadModelConfig::Validate() const {
|
||||||
|
if (provider != "rknn") {
|
||||||
|
if (!silero_vad.model.empty() && EndsWith(silero_vad.model, ".rknn")) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"--provider is %s, which is not rknn, but you pass an rknn model "
|
||||||
|
"'%s'",
|
||||||
|
provider.c_str(), silero_vad.model.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (provider == "rknn") {
|
||||||
|
if (!silero_vad.model.empty() && EndsWith(silero_vad.model, ".onnx")) {
|
||||||
|
SHERPA_ONNX_LOGE("--provider is rknn, but you pass an onnx model '%s'",
|
||||||
|
silero_vad.model.c_str());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return silero_vad.Validate();
|
||||||
|
}
|
||||||
|
|
||||||
std::string VadModelConfig::ToString() const {
|
std::string VadModelConfig::ToString() const {
|
||||||
std::ostringstream os;
|
std::ostringstream os;
|
||||||
|
|||||||
@@ -13,19 +13,27 @@
|
|||||||
#include "rawfile/raw_file_manager.h"
|
#include "rawfile/raw_file_manager.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if SHERPA_ONNX_ENABLE_RKNN
|
||||||
|
#include "sherpa-onnx/csrc/rknn/silero-vad-model-rknn.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/silero-vad-model.h"
|
#include "sherpa-onnx/csrc/silero-vad-model.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) {
|
std::unique_ptr<VadModel> VadModel::Create(const VadModelConfig &config) {
|
||||||
// TODO(fangjun): Support other VAD models.
|
if (config.provider == "rknn") {
|
||||||
|
return std::make_unique<SileroVadModelRknn>(config);
|
||||||
|
}
|
||||||
return std::make_unique<SileroVadModel>(config);
|
return std::make_unique<SileroVadModel>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Manager>
|
template <typename Manager>
|
||||||
std::unique_ptr<VadModel> VadModel::Create(Manager *mgr,
|
std::unique_ptr<VadModel> VadModel::Create(Manager *mgr,
|
||||||
const VadModelConfig &config) {
|
const VadModelConfig &config) {
|
||||||
// TODO(fangjun): Support other VAD models.
|
if (config.provider == "rknn") {
|
||||||
|
return std::make_unique<SileroVadModelRknn>(mgr, config);
|
||||||
|
}
|
||||||
return std::make_unique<SileroVadModel>(mgr, config);
|
return std::make_unique<SileroVadModel>(mgr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user