Add C++ runtime for speaker verification models from NeMo (#527)

This commit is contained in:
Fangjun Kuang
2024-01-13 21:42:09 +08:00
committed by GitHub
parent 68a525a024
commit 2024e96639
20 changed files with 405 additions and 24 deletions

View File

@@ -56,7 +56,7 @@ def load_speaker_embedding_model(model_filename):
return extractor
def test_wespeaker_model(model_filename: str):
def test_zh_models(model_filename: str):
model_filename = str(model_filename)
if "en" in model_filename:
print(f"skip {model_filename}")
@@ -114,8 +114,9 @@ def test_wespeaker_model(model_filename: str):
assert ans == name, (name, ans)
def test_3dspeaker_model(model_filename: str):
extractor = load_speaker_embedding_model(str(model_filename))
def test_en_and_zh_models(model_filename: str):
model_filename = str(model_filename)
extractor = load_speaker_embedding_model(model_filename)
manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
filenames = [
@@ -124,7 +125,14 @@ def test_3dspeaker_model(model_filename: str):
"speaker1_a_en_16k",
"speaker2_a_en_16k",
]
is_en = "en" in model_filename
for filename in filenames:
if is_en and "cn" in filename:
continue
if not is_en and "en" in filename:
continue
name = filename.rsplit("_", maxsplit=1)[0]
data, sample_rate = read_wave(
f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
@@ -145,6 +153,11 @@ def test_3dspeaker_model(model_filename: str):
"speaker1_b_en_16k",
]
for filename in filenames:
if is_en and "cn" in filename:
continue
if not is_en and "en" in filename:
continue
print(filename)
name = filename.rsplit("_", maxsplit=1)[0]
name = name.replace("b_cn", "a_cn")
@@ -178,7 +191,8 @@ class TestSpeakerRecognition(unittest.TestCase):
return
for filename in model_dir.glob("*.onnx"):
print(filename)
test_wespeaker_model(filename)
test_zh_models(filename)
test_en_and_zh_models(filename)
def test_3dpeaker_models(self):
model_dir = Path(d) / "3dspeaker"
@@ -187,7 +201,16 @@ class TestSpeakerRecognition(unittest.TestCase):
return
for filename in model_dir.glob("*.onnx"):
print(filename)
test_3dspeaker_model(filename)
test_en_and_zh_models(filename)
def test_nemo_models(self):
model_dir = Path(d) / "nemo"
if not model_dir.is_dir():
print(f"{model_dir} does not exist - skip it")
return
for filename in model_dir.glob("*.onnx"):
print(filename)
test_en_and_zh_models(filename)
if __name__ == "__main__":