218 lines
7.0 KiB
Python
Executable File
218 lines
7.0 KiB
Python
Executable File
# sherpa-onnx/python/tests/test_speaker_recognition.py
|
|
#
|
|
# Copyright (c) 2024 Xiaomi Corporation
|
|
#
|
|
# To run this single test, use
|
|
#
|
|
# ctest --verbose -R test_speaker_recognition_py
|
|
|
|
import unittest
|
|
import wave
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from typing import Tuple
|
|
|
|
import numpy as np
|
|
import sherpa_onnx
|
|
|
|
# Root directory holding the speaker-recognition models; the test class below
# looks for per-toolkit sub-directories ("wespeaker", "3dspeaker", "nemo") here.
d = "/tmp/sr-models"
|
|
|
|
|
|
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a single-channel, 16-bit wave file as float samples.

    Args:
      wave_filename:
        Path to a wave file. It must have exactly one channel and a sample
        width of 2 bytes. Any sample rate is accepted.
    Returns:
      Return a tuple containing:
       - A 1-D array of dtype np.float32 holding the samples divided by
         32768, i.e. normalized to the range [-1, 1].
       - The sample rate stored in the wave file.
    Raises:
      AssertionError: If the file is not mono or its samples are not 16-bit.
    """
    with wave.open(wave_filename) as wav:
        assert wav.getnchannels() == 1, wav.getnchannels()
        assert wav.getsampwidth() == 2, wav.getsampwidth()  # width is in bytes

        sample_rate = wav.getframerate()
        raw_bytes = wav.readframes(wav.getnframes())

    # Reinterpret the raw PCM bytes as int16 and scale them into [-1, 1].
    pcm = np.frombuffer(raw_bytes, dtype=np.int16)
    normalized = pcm.astype(np.float32) / 32768
    return normalized, sample_rate
|
|
|
|
|
|
def load_speaker_embedding_model(model_filename):
    """Build a speaker embedding extractor for a model file.

    Args:
      model_filename:
        Path of the model, forwarded as ``model`` to
        ``sherpa_onnx.SpeakerEmbeddingExtractorConfig``.
    Returns:
      A ``sherpa_onnx.SpeakerEmbeddingExtractor`` configured with a single
      thread, debug output enabled and the "cpu" provider.
    Raises:
      ValueError: If ``config.validate()`` returns a falsy value.
    """
    options = {
        "model": model_filename,
        "num_threads": 1,
        "debug": True,
        "provider": "cpu",
    }
    config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(**options)

    if config.validate():
        return sherpa_onnx.SpeakerEmbeddingExtractor(config)

    raise ValueError(f"Invalid config. {config}")
|
|
|
|
|
|
def test_zh_models(model_filename: str):
    """Enroll speakers from the enroll set and verify them on the test set.

    Models whose path contains "en" are skipped. For every other model, the
    enrollment utterances are grouped by speaker (the part of the file name
    before the first "-"), the per-speaker embeddings are averaged and
    registered in a ``SpeakerEmbeddingManager``, and then each test utterance
    must both verify against and be found as its own speaker with a
    threshold of 0.5.

    Args:
      model_filename:
        Path to the speaker embedding model; converted with ``str()``.
    Raises:
      RuntimeError: If a speaker cannot be registered or verified.
      AssertionError: If a stream is not ready or the search result differs
        from the expected speaker.
    """
    model_filename = str(model_filename)
    if "en" in model_filename:
        print(f"skip {model_filename}")
        return

    extractor = load_speaker_embedding_model(model_filename)

    def embed(wave_path):
        # Feed one complete utterance to the extractor and return its embedding.
        data, sample_rate = read_wave(wave_path)
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        return np.array(extractor.compute(stream))

    # --- Enrollment: collect every embedding under its speaker name. ---
    enroll_files = (
        "leijun-sr-1",
        "leijun-sr-2",
        "fangjun-sr-1",
        "fangjun-sr-2",
        "fangjun-sr-3",
    )
    per_speaker = {}
    for filename in enroll_files:
        print(filename)
        name = filename.partition("-")[0]
        embedding = embed(f"/tmp/sr-models/sr-data/enroll/{filename}.wav")
        per_speaker.setdefault(name, []).append(embedding)

    # Register the mean embedding of each speaker.
    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
    for name, embedding_list in per_speaker.items():
        print(name, len(embedding_list))
        mean_embedding = sum(embedding_list) / len(embedding_list)
        if not manager.add(name, mean_embedding):
            raise RuntimeError(f"Failed to register speaker {name}")

    # --- Verification: every test utterance must match its own speaker. ---
    test_files = (
        "leijun-test-sr-1",
        "leijun-test-sr-2",
        "leijun-test-sr-3",
        "fangjun-test-sr-1",
        "fangjun-test-sr-2",
    )
    for filename in test_files:
        name = filename.partition("-")[0]
        embedding = embed(f"/tmp/sr-models/sr-data/test/{filename}.wav")

        if not manager.verify(name, embedding, threshold=0.5):
            raise RuntimeError(f"Failed to verify {name} with wave {filename}.wav")

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)
|
|
|
|
|
|
def test_en_and_zh_models(model_filename: str):
    """Register the "a" utterances and verify the "b" utterance of speaker1.

    Whether the model counts as English is decided by "en" occurring in its
    path. English models only use the ``*_en_*`` waves and all other models
    only use the ``*_cn_*`` waves. Each "a" utterance is registered under its
    file name without the trailing ``_16k``; the "b" utterance of speaker1
    must then verify against, and be found as, the matching "a" entry with a
    threshold of 0.5.

    Args:
      model_filename:
        Path to the speaker embedding model; converted with ``str()``.
    Raises:
      RuntimeError: If a speaker cannot be registered or verified.
      AssertionError: If a stream is not ready or the search result differs
        from the expected speaker.
    """
    model_filename = str(model_filename)
    extractor = load_speaker_embedding_model(model_filename)
    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)

    def embed(wave_stem):
        # Feed one complete utterance to the extractor and return its embedding.
        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{wave_stem}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        return np.array(extractor.compute(stream))

    is_en = "en" in model_filename
    # Waves carrying the tag of the *other* language are ignored.
    skipped_tag = "cn" if is_en else "en"

    # --- Registration of the "a" utterances. ---
    for filename in (
        "speaker1_a_cn_16k",
        "speaker2_a_cn_16k",
        "speaker1_a_en_16k",
        "speaker2_a_en_16k",
    ):
        if skipped_tag in filename:
            continue

        name = filename.rsplit("_", maxsplit=1)[0]
        embedding = embed(filename)

        if not manager.add(name, embedding):
            raise RuntimeError(f"Failed to register speaker {name}")

    # --- Verification with the "b" utterances. ---
    for filename in (
        "speaker1_b_cn_16k",
        "speaker1_b_en_16k",
    ):
        if skipped_tag in filename:
            continue

        print(filename)
        # Map the "b" wave onto the name its "a" counterpart was registered as.
        name = filename.rsplit("_", maxsplit=1)[0]
        name = name.replace("b_cn", "a_cn").replace("b_en", "a_en")
        print(name)

        embedding = embed(filename)
        if not manager.verify(name, embedding, threshold=0.5):
            raise RuntimeError(
                f"Failed to verify {name} with wave {filename}.wav. model: {model_filename}"
            )

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)
|
|
|
|
|
|
class TestSpeakerRecognition(unittest.TestCase):
    """Run the speaker-recognition checks on every ``*.onnx`` model found.

    Each test looks into one sub-directory of ``d``. A missing directory is
    reported with a message and the test returns without failing.
    """

    def _check_models_in(self, subdir, *checks):
        # Apply every check, in order, to each *.onnx file inside d/subdir.
        model_dir = Path(d) / subdir
        if not model_dir.is_dir():
            print(f"{model_dir} does not exist - skip it")
            return

        for filename in model_dir.glob("*.onnx"):
            print(filename)
            for check in checks:
                check(filename)

    def test_wespeaker_models(self):
        self._check_models_in("wespeaker", test_zh_models, test_en_and_zh_models)

    def test_3dpeaker_models(self):
        self._check_models_in("3dspeaker", test_en_and_zh_models)

    def test_nemo_models(self):
        self._check_models_in("nemo", test_en_and_zh_models)
|
|
|
|
|
|
if __name__ == "__main__":
    # Let the file be executed directly as a script, in addition to ctest.
    unittest.main()
|