Add C++ runtime for models from 3d-speaker (#523)
This commit is contained in:
@@ -23,6 +23,7 @@ set(py_test_files
|
||||
test_offline_recognizer.py
|
||||
test_online_recognizer.py
|
||||
test_online_transducer_model_config.py
|
||||
test_speaker_recognition.py
|
||||
test_text2token.py
|
||||
)
|
||||
|
||||
|
||||
0
sherpa-onnx/python/tests/test_feature_extractor_config.py
Normal file → Executable file
0
sherpa-onnx/python/tests/test_feature_extractor_config.py
Normal file → Executable file
0
sherpa-onnx/python/tests/test_online_transducer_model_config.py
Normal file → Executable file
0
sherpa-onnx/python/tests/test_online_transducer_model_config.py
Normal file → Executable file
194
sherpa-onnx/python/tests/test_speaker_recognition.py
Executable file
194
sherpa-onnx/python/tests/test_speaker_recognition.py
Executable file
@@ -0,0 +1,194 @@
|
||||
# sherpa-onnx/python/tests/test_speaker_recognition.py
|
||||
#
|
||||
# Copyright (c) 2024 Xiaomi Corporation
|
||||
#
|
||||
# To run this single test, use
|
||||
#
|
||||
# ctest --verbose -R test_speaker_recognition_py
|
||||
|
||||
import unittest
|
||||
import wave
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
|
||||
# Root directory holding the downloaded speaker-embedding models and the
# enrollment/test wave files; the tests below skip themselves when absent.
d = "/tmp/sr-models"
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
    """Load a mono, 16-bit PCM wave file.

    Args:
      wave_filename:
        Path to a wave file. It must have exactly one channel and 16-bit
        samples; its sample rate does not need to be 16 kHz.
    Returns:
      A tuple containing:
        - A 1-D np.float32 array with the samples normalized to [-1, 1].
        - The sample rate of the wave file.
    """
    with wave.open(wave_filename) as f:
        num_channels = f.getnchannels()
        assert num_channels == 1, num_channels
        sample_width = f.getsampwidth()  # in bytes per sample
        assert sample_width == 2, sample_width

        raw = f.readframes(f.getnframes())
        sample_rate = f.getframerate()

    pcm = np.frombuffer(raw, dtype=np.int16)
    # Scale the int16 range [-32768, 32767] down to [-1, 1).
    samples_float32 = pcm.astype(np.float32) / 32768
    return samples_float32, sample_rate
|
||||
|
||||
|
||||
def load_speaker_embedding_model(model_filename):
    """Build a CPU, single-threaded SpeakerEmbeddingExtractor for a model.

    Raises:
      ValueError: if the resulting config fails validation (e.g. the model
        file does not exist).
    """
    config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
        model=model_filename,
        num_threads=1,
        debug=True,
        provider="cpu",
    )
    if not config.validate():
        raise ValueError(f"Invalid config. {config}")
    return sherpa_onnx.SpeakerEmbeddingExtractor(config)
|
||||
|
||||
|
||||
def test_wespeaker_model(model_filename: str):
    """Enroll and verify speakers using one wespeaker onnx model.

    Enrolls two speakers from /tmp/sr-models/sr-data/enroll (averaging the
    per-speaker embeddings), then verifies and searches each test utterance
    in /tmp/sr-models/sr-data/test against the registered speakers.

    Models whose filename contains "en" are skipped — the test waves here
    are presumably Chinese speech (TODO confirm against the data set).

    Raises:
      RuntimeError: if registration or verification fails.
    """
    model_filename = str(model_filename)
    if "en" in model_filename:
        print(f"skip {model_filename}")
        return
    extractor = load_speaker_embedding_model(model_filename)
    filenames = [
        "leijun-sr-1",
        "leijun-sr-2",
        "fangjun-sr-1",
        "fangjun-sr-2",
        "fangjun-sr-3",
    ]
    tmp = defaultdict(list)
    for filename in filenames:
        print(filename)
        # The speaker name is the part before the first "-", e.g. "leijun".
        name = filename.split("-", maxsplit=1)[0]
        # Fix: interpolate the loop variable into the path. The original
        # f-string had no placeholder, so every iteration tried to read the
        # same literal "(unknown).wav" file.
        data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/enroll/{filename}.wav")
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        tmp[name].append(embedding)

    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
    for name, embedding_list in tmp.items():
        print(name, len(embedding_list))
        # Register each speaker with the mean of their enrollment embeddings.
        embedding = sum(embedding_list) / len(embedding_list)
        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")

    filenames = [
        "leijun-test-sr-1",
        "leijun-test-sr-2",
        "leijun-test-sr-3",
        "fangjun-test-sr-1",
        "fangjun-test-sr-2",
    ]
    for filename in filenames:
        name = filename.split("-", maxsplit=1)[0]
        # Fix: same missing-placeholder defect as in the enrollment loop.
        data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/test/{filename}.wav")
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        status = manager.verify(name, embedding, threshold=0.5)
        if not status:
            raise RuntimeError(f"Failed to verify {name} with wave {filename}.wav")

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)
|
||||
|
||||
|
||||
def test_3dspeaker_model(model_filename: str):
    """Enroll and verify speakers using one 3d-speaker onnx model.

    Enrolls the "a" utterance of each speaker/language pair from
    /tmp/sr-models/sr-data/test/3d-speaker, then verifies the "b"
    utterances against the corresponding registered "a" entry.

    Raises:
      RuntimeError: if registration or verification fails.
    """
    extractor = load_speaker_embedding_model(str(model_filename))
    manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)

    filenames = [
        "speaker1_a_cn_16k",
        "speaker2_a_cn_16k",
        "speaker1_a_en_16k",
        "speaker2_a_en_16k",
    ]
    for filename in filenames:
        # Drop the trailing sample-rate suffix, e.g. "speaker1_a_cn_16k"
        # -> registered name "speaker1_a_cn".
        name = filename.rsplit("_", maxsplit=1)[0]
        # Fix: interpolate the loop variable into the path. The original
        # f-string had no placeholder, so every iteration tried to read the
        # same literal "(unknown).wav" file.
        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)

        status = manager.add(name, embedding)
        if not status:
            raise RuntimeError(f"Failed to register speaker {name}")

    filenames = [
        "speaker1_b_cn_16k",
        "speaker1_b_en_16k",
    ]
    for filename in filenames:
        print(filename)
        name = filename.rsplit("_", maxsplit=1)[0]
        # The "b" utterances must match the registered "a" entries.
        name = name.replace("b_cn", "a_cn")
        name = name.replace("b_en", "a_en")
        print(name)

        # Fix: same missing-placeholder defect as in the enrollment loop.
        data, sample_rate = read_wave(
            f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
        )
        stream = extractor.create_stream()
        stream.accept_waveform(sample_rate=sample_rate, waveform=data)
        stream.input_finished()
        assert extractor.is_ready(stream)
        embedding = extractor.compute(stream)
        embedding = np.array(embedding)
        status = manager.verify(name, embedding, threshold=0.5)
        if not status:
            raise RuntimeError(
                f"Failed to verify {name} with wave {filename}.wav. model: {model_filename}"
            )

        ans = manager.search(embedding, threshold=0.5)
        assert ans == name, (name, ans)
|
||||
|
||||
|
||||
class TestSpeakerRecognition(unittest.TestCase):
    """Run the per-model checks over every onnx file under /tmp/sr-models."""

    def _for_each_model(self, subdir: str, run) -> None:
        # Shared driver: silently skip when the model directory was not
        # downloaded, otherwise apply `run` to each *.onnx file in it.
        model_dir = Path(d) / subdir
        if not model_dir.is_dir():
            print(f"{model_dir} does not exist - skip it")
            return
        for filename in model_dir.glob("*.onnx"):
            print(filename)
            run(filename)

    def test_wespeaker_models(self):
        self._for_each_model("wespeaker", test_wespeaker_model)

    def test_3dpeaker_models(self):
        # NOTE(review): method name keeps the original "3dpeaker" spelling
        # in case external tooling filters on it — confirm before renaming.
        self._for_each_model("3dspeaker", test_3dspeaker_model)
|
||||
|
||||
|
||||
# Allow running this test file directly (ctest invokes it this way too).
if __name__ == "__main__":
    unittest.main()
|
||||
0
sherpa-onnx/python/tests/test_text2token.py
Normal file → Executable file
0
sherpa-onnx/python/tests/test_text2token.py
Normal file → Executable file
Reference in New Issue
Block a user