This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex-mr_series-sherpa-onnx/sherpa-onnx/python/tests/test_speaker_recognition.py

218 lines
7.0 KiB
Python
Executable File

# sherpa-onnx/python/tests/test_speaker_recognition.py
#
# Copyright (c) 2024 Xiaomi Corporation
#
# To run this single test, use
#
# ctest --verbose -R test_speaker_recognition_py
import unittest
import wave
from collections import defaultdict
from pathlib import Path
from typing import Tuple
import numpy as np
import sherpa_onnx
# Root directory where the pre-downloaded speaker-recognition models live.
d = "/tmp/sr-models"
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a single-channel, 16-bit PCM wave file.

    Args:
      wave_filename:
        Path to a wave file. It must be single channel and each sample
        must be 16-bit. Its sample rate does not need to be 16kHz.
    Returns:
      Return a tuple containing:
        - A 1-D array of dtype np.float32 containing the samples,
          normalized to the range [-1, 1].
        - The sample rate of the wave file.
    """
    with wave.open(wave_filename) as f:
        assert f.getnchannels() == 1, f.getnchannels()
        assert f.getsampwidth() == 2, f.getsampwidth()  # sample width is in bytes
        raw = f.readframes(f.getnframes())
        sample_rate = f.getframerate()
    # int16 -> float32, scaled into [-1, 1)
    pcm = np.frombuffer(raw, dtype=np.int16).astype(np.float32) / 32768
    return pcm, sample_rate
def load_speaker_embedding_model(model_filename):
    """Build a SpeakerEmbeddingExtractor for the given onnx model file.

    Runs on CPU with a single thread and debug logging enabled.

    Raises:
      ValueError: if the constructed config does not pass validation.
    """
    cfg = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        model=model_filename,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    if not cfg.validate():
        raise ValueError(f"Invalid config. {cfg}")
    return sherpa_onnx.SpeakerEmbeddingExtractor(cfg)
def test_zh_models(model_filename: str):
    """Enroll two speakers from Chinese wave files and verify held-out utterances.

    Averages per-speaker embeddings from the enroll set, registers them in a
    SpeakerEmbeddingManager, then checks that each test utterance both
    verifies against and is found as its own speaker.

    Models whose filename contains "en" are English-only and are skipped.

    Raises:
      RuntimeError: if registration, verification, or search fails.
    """
    model_filename = str(model_filename)
    if "en" in model_filename:
        print(f"skip {model_filename}")
        return
    extractor = load_speaker_embedding_model(model_filename)
    filenames = [
        "leijun-sr-1",
        "leijun-sr-2",
        "fangjun-sr-1",
        "fangjun-sr-2",
        "fangjun-sr-3",
    ]
    tmp = defaultdict(list)
    for filename in filenames:
        print(filename)
        # Speaker name is the prefix before the first "-", e.g. "leijun".
        name = filename.split("-", maxsplit=1)[0]
        # FIX: restore the {filename} placeholder that was garbled to a
        # literal "(unknown)" — without it every iteration read the same
        # nonexistent path.
        data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/enroll/{filename}.wav")
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        tmp[name].append(embedding)
    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
    for name, embedding_list in tmp.items():
        print(name, len(embedding_list))
        # Register the mean embedding over all enroll utterances.
        embedding = sum(embedding_list) / len(embedding_list)
        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")
    filenames = [
        "leijun-test-sr-1",
        "leijun-test-sr-2",
        "leijun-test-sr-3",
        "fangjun-test-sr-1",
        "fangjun-test-sr-2",
    ]
    for filename in filenames:
        name = filename.split("-", maxsplit=1)[0]
        # FIX: same garbled-placeholder repair as in the enroll loop above.
        data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/test/{filename}.wav")
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        status = manager.verify(name, embedding, threshold=0.5)
        if not status:
            raise RuntimeError(f"Failed to verify {name} with wave {filename}.wav")
        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)
def test_en_and_zh_models(model_filename: str):
    """Enroll one utterance per speaker and verify a second utterance each.

    Uses the 3d-speaker test data. English-only models (filename contains
    "en") are run only on English waves; all other models only on Chinese
    ("cn") waves.

    Raises:
      RuntimeError: if registration or verification fails.
    """
    model_filename = str(model_filename)
    extractor = load_speaker_embedding_model(model_filename)
    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
    filenames = [
        "speaker1_a_cn_16k",
        "speaker2_a_cn_16k",
        "speaker1_a_en_16k",
        "speaker2_a_en_16k",
    ]
    is_en = "en" in model_filename
    for filename in filenames:
        # Only feed each model the language it supports.
        if is_en and "cn" in filename:
            continue
        if not is_en and "en" in filename:
            continue
        # Speaker id keeps everything before the trailing "_16k" suffix.
        name = filename.rsplit("_", maxsplit=1)[0]
        # FIX: restore the {filename} placeholder that was garbled to a
        # literal "(unknown)" — the f-string had no placeholder at all.
        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")
    filenames = [
        "speaker1_b_cn_16k",
        "speaker1_b_en_16k",
    ]
    for filename in filenames:
        if is_en and "cn" in filename:
            continue
        if not is_en and "en" in filename:
            continue
        print(filename)
        name = filename.rsplit("_", maxsplit=1)[0]
        # The "b" utterances must match the registered "a" speaker ids.
        name = name.replace("b_cn", "a_cn")
        name = name.replace("b_en", "a_en")
        print(name)
        # FIX: same garbled-placeholder repair as in the enroll loop above.
        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        status = manager.verify(name, embedding, threshold=0.5)
        if not status:
            raise RuntimeError(
                f"Failed to verify {name} with wave {filename}.wav. model: {model_filename}"
            )
        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)
class TestSpeakerRecognition(unittest.TestCase):
    """Run the speaker-recognition checks over every downloaded model family."""

    def _for_each_model(self, subdir, *checks):
        # Apply every check to each *.onnx model under the given
        # sub-directory of the model root; quietly skip families that
        # were not downloaded.
        model_dir = Path(d) / subdir
        if not model_dir.is_dir():
            print(f"{model_dir} does not exist - skip it")
            return
        for model_path in model_dir.glob("*.onnx"):
            print(model_path)
            for check in checks:
                check(model_path)

    def test_wespeaker_models(self):
        self._for_each_model("wespeaker", test_zh_models, test_en_and_zh_models)

    def test_3dpeaker_models(self):
        self._for_each_model("3dspeaker", test_en_and_zh_models)

    def test_nemo_models(self):
        self._for_each_model("nemo", test_en_and_zh_models)
# Allow running this test file directly, in addition to running it via ctest.
if __name__ == "__main__":
    unittest.main()