Export silero_vad v4 to RKNN (#2067)
This commit is contained in:
147
scripts/silero_vad/v4/test-onnx.py
Executable file
147
scripts/silero_vad/v4/test-onnx.py
Executable file
@@ -0,0 +1,147 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2025 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
import onnxruntime as ort
|
||||
import argparse
|
||||
import soundfile as sf
|
||||
from typing import Tuple
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_args():
    """Parse command-line arguments: --model and --wav (both required)."""
    p = argparse.ArgumentParser()
    for flag, desc in (
        ("--model", "Path to the onnx model"),
        ("--wav", "Path to the input wav"),
    ):
        p.add_argument(flag, type=str, required=True, help=desc)
    return p.parse_args()
|
||||
|
||||
|
||||
class OnnxModel:
    """Thin wrapper around an onnxruntime CPU session for silero-vad v4."""

    def __init__(
        self,
        model: str,
    ):
        # Single-threaded execution keeps results deterministic and cheap.
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 1
        self.model = ort.InferenceSession(
            model,
            sess_options=opts,
            providers=["CPUExecutionProvider"],
        )

    def get_init_states(self):
        """Return zero-initialized LSTM states (h, c), each of shape (2, 1, 64)."""
        state_shape = (2, 1, 64)
        h = np.zeros(state_shape, dtype=np.float32)
        c = np.zeros(state_shape, dtype=np.float32)
        return h, c

    def __call__(self, x, h, c):
        """Run one VAD step.

        Args:
          x: 1-D window of samples; a leading batch dim is prepended, so the
             model sees (1, 512).
          h: LSTM hidden state, (2, 1, 64)
          c: LSTM cell state, (2, 1, 64)
        Returns:
          prob: (1, 1) speech probability
          next_h: (2, 1, 64)
          next_c: (2, 1, 64)
        """
        inputs = self.model.get_inputs()
        outputs = self.model.get_outputs()
        feed = {
            inputs[0].name: x[None],
            inputs[1].name: h,
            inputs[2].name: c,
        }
        fetch = [outputs[i].name for i in range(3)]
        prob, next_h, next_c = self.model.run(fetch, feed)
        return prob, next_h, next_c
|
||||
|
||||
|
||||
def load_audio(filename: str) -> Tuple[np.ndarray, int]:
    """Read a sound file and return (mono float32 samples, sample rate)."""
    audio, rate = sf.read(filename, always_2d=True, dtype="float32")
    # Multi-channel input: keep only channel 0.
    mono = audio[:, 0]
    return np.ascontiguousarray(mono), rate
|
||||
|
||||
|
||||
def main():
    """Run silero-vad v4 (onnx) over a wav file and print speech segments.

    Prints start/end times (seconds) of detected speech, or a message when
    no speech is found.
    """
    args = get_args()

    samples, sample_rate = load_audio(args.wav)
    if sample_rate != 16000:
        # The model expects 16 kHz input; resample on demand.
        import librosa

        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=16000)
        sample_rate = 16000

    model = OnnxModel(args.model)
    probs = []
    h, c = model.get_init_states()
    window_size = 512  # samples per VAD frame (32 ms at 16 kHz)
    num_windows = samples.shape[0] // window_size
    for i in range(num_windows):
        start = i * window_size
        end = start + window_size
        p, h, c = model(samples[start:end], h, c)
        probs.append(p[0].item())

    threshold = 0.5
    # Per-frame boolean speech decision.
    is_speech = [p > threshold for p in probs]

    # Both durations are 0.25 s expressed in frames.
    min_speech_duration = 0.25 * sample_rate / window_size
    min_silence_duration = 0.25 * sample_rate / window_size

    result = []
    last = -1  # frame index where the current speech run started; -1 = none
    # Fix: pre-bind k so the trailing flush below cannot raise NameError
    # when the audio is shorter than one window (is_speech is empty).
    k = -1
    for k, f in enumerate(is_speech):
        if f:
            if last == -1:
                last = k
        elif last != -1:
            if k - last > min_speech_duration:
                result.append((last, k))
                last = -1
            # NOTE(review): a too-short run deliberately keeps `last` set, so
            # it may be extended by later speech — confirm this is intended.

    # Flush a speech run that was still open at the end of the audio.
    if last != -1 and k - last > min_speech_duration:
        result.append((last, k))

    if not result:
        print(f"Empty for {args.wav}")
        return

    print(result)

    # Merge adjacent segments separated by less than min_silence_duration.
    final = [result[0]]
    for r in result[1:]:
        f = final[-1]
        if r[0] - f[1] < min_silence_duration:
            final[-1] = (f[0], r[1])
        else:
            final.append(r)

    for f in final:
        start = f[0] * window_size / sample_rate
        end = f[1] * window_size / sample_rate
        print("{:.3f} -- {:.3f}".format(start, end))
|
||||
|
||||
|
||||
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user