diff --git a/.github/scripts/test-speaker-recognition-python.sh b/.github/scripts/test-speaker-recognition-python.sh new file mode 100755 index 00000000..6131983d --- /dev/null +++ b/.github/scripts/test-speaker-recognition-python.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +d=/tmp/sr-models +mkdir -p $d + +pushd $d +log "Download test waves" +git clone https://github.com/csukuangfj/sr-data +popd + +log "Download wespeaker models" +model_dir=$d/wespeaker +mkdir -p $model_dir +pushd $model_dir +models=( +en_voxceleb_CAM++.onnx +en_voxceleb_CAM++_LM.onnx +en_voxceleb_resnet152_LM.onnx +en_voxceleb_resnet221_LM.onnx +en_voxceleb_resnet293_LM.onnx +en_voxceleb_resnet34.onnx +en_voxceleb_resnet34_LM.onnx +zh_cnceleb_resnet34.onnx +zh_cnceleb_resnet34_LM.onnx +) +for m in ${models[@]}; do + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m +done +ls -lh +popd + +log "Download 3d-speaker models" +model_dir=$d/3dspeaker +mkdir -p $model_dir +pushd $model_dir +models=( +speech_campplus_sv_en_voxceleb_16k.onnx +speech_campplus_sv_zh-cn_16k-common.onnx +speech_eres2net_base_200k_sv_zh-cn_16k-common.onnx +speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx +speech_eres2net_large_sv_zh-cn_3dspeaker_16k.onnx +speech_eres2net_sv_en_voxceleb_16k.onnx +speech_eres2net_sv_zh-cn_16k-common.onnx +) +for m in ${models[@]}; do + wget -q https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m +done +ls -lh +popd + + +python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose diff --git a/.github/workflows/run-python-test.yaml b/.github/workflows/run-python-test.yaml index ddde2ff0..351c38eb 100644 --- a/.github/workflows/run-python-test.yaml +++ b/.github/workflows/run-python-test.yaml @@ -76,6 +76,7 @@ jobs: - name: Test sherpa-onnx shell: bash run: | + .github/scripts/test-speaker-recognition-python.sh .github/scripts/test-python.sh - uses: actions/upload-artifact@v3 diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index ba0c5645..38fef1c5 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -99,7 +99,7 @@ set(sources # speaker embedding extractor list(APPEND sources speaker-embedding-extractor-impl.cc - speaker-embedding-extractor-wespeaker-model.cc + speaker-embedding-extractor-model.cc speaker-embedding-extractor.cc speaker-embedding-manager.cc ) diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h b/sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h similarity index 61% rename from sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h rename to sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h index b408d9de..eb87d904 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h @@ -1,23 +1,24 @@ -// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h +// sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h // // Copyright (c) 2023 Xiaomi Corporation -#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_ -#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_ +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ #include #include #include #include +#include "Eigen/Dense" #include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h" -#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h" namespace sherpa_onnx { -class SpeakerEmbeddingExtractorWeSpeakerImpl +class SpeakerEmbeddingExtractorGeneralImpl : public SpeakerEmbeddingExtractorImpl { public: - explicit SpeakerEmbeddingExtractorWeSpeakerImpl( + explicit SpeakerEmbeddingExtractorGeneralImpl( const SpeakerEmbeddingExtractorConfig &config) : model_(config) {} @@ -25,7 +26,7 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl std::unique_ptr CreateStream() const override { FeatureExtractorConfig feat_config; - auto meta_data = model_.GetMetaData(); + const auto &meta_data = model_.GetMetaData(); feat_config.sampling_rate = meta_data.sample_rate; feat_config.normalize_samples = meta_data.normalize_samples; @@ -52,6 +53,17 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl int32_t feat_dim = features.size() / num_frames; + const auto &meta_data = model_.GetMetaData(); + if (!meta_data.feature_normalize_type.empty()) { + if (meta_data.feature_normalize_type == "global-mean") { + SubtractGlobalMean(features.data(), num_frames, feat_dim); + } else { + SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s", + meta_data.feature_normalize_type.c_str()); + exit(-1); + } + } + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); @@ -71,9 +83,19 @@ class SpeakerEmbeddingExtractorWeSpeakerImpl } private: - SpeakerEmbeddingExtractorWeSpeakerModel model_; + void SubtractGlobalMean(float *p, int32_t num_frames, + int32_t feat_dim) const { + auto m = Eigen::Map< + Eigen::Matrix>( + p, num_frames, feat_dim); + + m = m.rowwise() - m.colwise().mean(); + } + + private: + SpeakerEmbeddingExtractorModel model_; }; } // namespace sherpa_onnx -#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_ +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc index 6dff5ac5..46cdfa61 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc @@ -5,7 +5,7 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" -#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h" namespace sherpa_onnx { @@ -13,6 +13,7 @@ namespace { enum class ModelType { kWeSpeaker, + k3dSpeaker, kUnkown, }; @@ -49,6 +50,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, if (model_type.get() == std::string("wespeaker")) { return ModelType::kWeSpeaker; + } else if (model_type.get() == std::string("3d-speaker")) { + return ModelType::k3dSpeaker; } else { SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get()); return ModelType::kUnkown; @@ -68,7 +71,9 @@ SpeakerEmbeddingExtractorImpl::Create( switch (model_type) { case ModelType::kWeSpeaker: - return std::make_unique(config); + // fall through + case ModelType::k3dSpeaker: + return std::make_unique(config); case ModelType::kUnkown: SHERPA_ONNX_LOGE( "Unknown model type in for speaker embedding extractor!"); diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h b/sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h new file mode 100644 index 00000000..530938b5 --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h @@ -0,0 +1,28 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h +// +// Copyright (c) 2023 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_ + +#include +#include + +namespace sherpa_onnx { + +struct SpeakerEmbeddingExtractorModelMetaData { + int32_t output_dim = 0; + int32_t sample_rate = 0; + + // for wespeaker models, it is 0; + // for 3d-speaker models, it is 1 + int32_t normalize_samples = 1; + + // Chinese, English, etc. + std::string language; + + // for 3d-speaker, it is global-mean + std::string feature_normalize_type; +}; + +} // namespace sherpa_onnx +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_META_DATA_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc b/sherpa-onnx/csrc/speaker-embedding-extractor-model.cc similarity index 70% rename from sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc rename to sherpa-onnx/csrc/speaker-embedding-extractor-model.cc index b934f28a..fedfcab5 100644 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-model.cc @@ -1,8 +1,8 @@ -// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc +// sherpa-onnx/csrc/speaker-embedding-extractor-model.cc // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2023-2024 Xiaomi Corporation -#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h" #include #include @@ -11,11 +11,11 @@ #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/session.h" -#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h" namespace sherpa_onnx { -class SpeakerEmbeddingExtractorWeSpeakerModel::Impl { +class SpeakerEmbeddingExtractorModel::Impl { public: explicit Impl(const SpeakerEmbeddingExtractorConfig &config) : config_(config), @@ -37,7 +37,7 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl { return std::move(outputs[0]); } - const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const { + const SpeakerEmbeddingExtractorModelMetaData &GetMetaData() const { return meta_data_; } @@ -65,10 +65,13 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl { "normalize_samples"); SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language"); + SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT( + meta_data_.feature_normalize_type, "feature_normalize_type", ""); + std::string framework; SHERPA_ONNX_READ_META_DATA_STR(framework, "framework"); - if (framework != "wespeaker") { - SHERPA_ONNX_LOGE("Expect a wespeaker model, given: %s", + if (framework != "wespeaker" && framework != "3d-speaker") { + SHERPA_ONNX_LOGE("Expect a wespeaker or a 3d-speaker model, given: %s", framework.c_str()); exit(-1); } @@ -88,24 +91,21 @@ class SpeakerEmbeddingExtractorWeSpeakerModel::Impl { std::vector output_names_; std::vector output_names_ptr_; - SpeakerEmbeddingExtractorWeSpeakerModelMetaData meta_data_; + SpeakerEmbeddingExtractorModelMetaData meta_data_; }; -SpeakerEmbeddingExtractorWeSpeakerModel:: - SpeakerEmbeddingExtractorWeSpeakerModel( - const SpeakerEmbeddingExtractorConfig &config) +SpeakerEmbeddingExtractorModel::SpeakerEmbeddingExtractorModel( + const SpeakerEmbeddingExtractorConfig &config) : impl_(std::make_unique(config)) {} -SpeakerEmbeddingExtractorWeSpeakerModel:: - ~SpeakerEmbeddingExtractorWeSpeakerModel() = default; +SpeakerEmbeddingExtractorModel::~SpeakerEmbeddingExtractorModel() = default; -const SpeakerEmbeddingExtractorWeSpeakerModelMetaData & -SpeakerEmbeddingExtractorWeSpeakerModel::GetMetaData() const { +const SpeakerEmbeddingExtractorModelMetaData & +SpeakerEmbeddingExtractorModel::GetMetaData() const { return impl_->GetMetaData(); } -Ort::Value SpeakerEmbeddingExtractorWeSpeakerModel::Compute( - Ort::Value x) const { +Ort::Value SpeakerEmbeddingExtractorModel::Compute(Ort::Value x) const { return impl_->Compute(std::move(x)); } diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-model.h b/sherpa-onnx/csrc/speaker-embedding-extractor-model.h new file mode 100644 index 00000000..3fa94ef3 --- /dev/null +++ b/sherpa-onnx/csrc/speaker-embedding-extractor-model.h @@ -0,0 +1,37 @@ +// sherpa-onnx/csrc/speaker-embedding-extractor-model.h +// +// Copyright (c) 2023-2024 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ +#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/speaker-embedding-extractor-model-meta-data.h" +#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" + +namespace sherpa_onnx { + +class SpeakerEmbeddingExtractorModel { + public: + explicit SpeakerEmbeddingExtractorModel( + const SpeakerEmbeddingExtractorConfig &config); + + ~SpeakerEmbeddingExtractorModel(); + + const SpeakerEmbeddingExtractorModelMetaData &GetMetaData() const; + + /** + * @param x A float32 tensor of shape (N, T, C) + * @return A float32 tensor of shape (N, C) + */ + Ort::Value Compute(Ort::Value x) const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h b/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h deleted file mode 100644 index 32ee76c6..00000000 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h +++ /dev/null @@ -1,20 +0,0 @@ -// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h -// -// Copyright (c) 2023 Xiaomi Corporation -#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_ -#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_ - -#include -#include - -namespace sherpa_onnx { - -struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData { - int32_t output_dim = 0; - int32_t sample_rate = 0; - int32_t normalize_samples = 0; - std::string language; -}; - -} // namespace sherpa_onnx -#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_ diff --git a/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h b/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h deleted file mode 100644 index f0b910f3..00000000 --- a/sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h +++ /dev/null @@ -1,37 +0,0 @@ -// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h -// -// Copyright (c) 2023 Xiaomi Corporation -#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_ -#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_ - -#include - -#include "onnxruntime_cxx_api.h" // NOLINT -#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h" -#include "sherpa-onnx/csrc/speaker-embedding-extractor.h" - -namespace sherpa_onnx { - -class SpeakerEmbeddingExtractorWeSpeakerModel { - public: - explicit SpeakerEmbeddingExtractorWeSpeakerModel( - const SpeakerEmbeddingExtractorConfig &config); - - ~SpeakerEmbeddingExtractorWeSpeakerModel(); - - const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const; - - /** - * @param x A float32 tensor of shape (N, T, C) - * @return A float32 tensor of shape (N, C) - */ - Ort::Value Compute(Ort::Value x) const; - - private: - class Impl; - std::unique_ptr impl_; -}; - -} // namespace sherpa_onnx - -#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_ diff --git a/sherpa-onnx/python/tests/CMakeLists.txt b/sherpa-onnx/python/tests/CMakeLists.txt index 4fd28529..e99636e2 100644 --- a/sherpa-onnx/python/tests/CMakeLists.txt +++ b/sherpa-onnx/python/tests/CMakeLists.txt @@ -23,6 +23,7 @@ set(py_test_files test_offline_recognizer.py test_online_recognizer.py test_online_transducer_model_config.py + test_speaker_recognition.py test_text2token.py ) diff --git a/sherpa-onnx/python/tests/test_feature_extractor_config.py b/sherpa-onnx/python/tests/test_feature_extractor_config.py old mode 100644 new mode 100755 diff --git a/sherpa-onnx/python/tests/test_online_transducer_model_config.py b/sherpa-onnx/python/tests/test_online_transducer_model_config.py old mode 100644 new mode 100755 diff --git a/sherpa-onnx/python/tests/test_speaker_recognition.py b/sherpa-onnx/python/tests/test_speaker_recognition.py new file mode 100755 index 00000000..e05ae2a0 --- /dev/null +++ b/sherpa-onnx/python/tests/test_speaker_recognition.py @@ -0,0 +1,194 @@ +# sherpa-onnx/python/tests/test_speaker_recognition.py +# +# Copyright (c) 2024 Xiaomi Corporation +# +# To run this single test, use +# +# ctest --verbose -R test_speaker_recognition_py + +import unittest +import wave +from collections import defaultdict +from pathlib import Path +from typing import Tuple + +import numpy as np +import sherpa_onnx + +d = "/tmp/sr-models" + + +def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: + """ + Args: + wave_filename: + Path to a wave file. It should be single channel and each sample should + be 16-bit. Its sample rate does not need to be 16kHz. + Returns: + Return a tuple containing: + - A 1-D array of dtype np.float32 containing the samples, which are + normalized to the range [-1, 1]. + - sample rate of the wave file + """ + + with wave.open(wave_filename) as f: + assert f.getnchannels() == 1, f.getnchannels() + assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes + num_samples = f.getnframes() + samples = f.readframes(num_samples) + samples_int16 = np.frombuffer(samples, dtype=np.int16) + samples_float32 = samples_int16.astype(np.float32) + + samples_float32 = samples_float32 / 32768 + return samples_float32, f.getframerate() + + +def load_speaker_embedding_model(model_filename): + config = sherpa_onnx.SpeakerEmbeddingExtractorConfig( + model=model_filename, + num_threads=1, + debug=True, + provider="cpu", + ) + if not config.validate(): + raise ValueError(f"Invalid config. {config}") + extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config) + return extractor + + +def test_wespeaker_model(model_filename: str): + model_filename = str(model_filename) + if "en" in model_filename: + print(f"skip {model_filename}") + return + extractor = load_speaker_embedding_model(model_filename) + filenames = [ + "leijun-sr-1", + "leijun-sr-2", + "fangjun-sr-1", + "fangjun-sr-2", + "fangjun-sr-3", + ] + tmp = defaultdict(list) + for filename in filenames: + print(filename) + name = filename.split("-", maxsplit=1)[0] + data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/enroll/{filename}.wav") + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=data) + stream.input_finished() + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + tmp[name].append(embedding) + + manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim) + for name, embedding_list in tmp.items(): + print(name, len(embedding_list)) + embedding = sum(embedding_list) / len(embedding_list) + status = manager.add(name, embedding) + if not status: + raise RuntimeError(f"Failed to register speaker {name}") + + filenames = [ + "leijun-test-sr-1", + "leijun-test-sr-2", + "leijun-test-sr-3", + "fangjun-test-sr-1", + "fangjun-test-sr-2", + ] + for filename in filenames: + name = filename.split("-", maxsplit=1)[0] + data, sample_rate = read_wave(f"/tmp/sr-models/sr-data/test/{filename}.wav") + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=data) + stream.input_finished() + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + status = manager.verify(name, embedding, threshold=0.5) + if not status: + raise RuntimeError(f"Failed to verify {name} with wave {filename}.wav") + + ans = manager.search(embedding, threshold=0.5) + assert ans == name, (name, ans) + + +def test_3dspeaker_model(model_filename: str): + extractor = load_speaker_embedding_model(str(model_filename)) + manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim) + + filenames = [ + "speaker1_a_cn_16k", + "speaker2_a_cn_16k", + "speaker1_a_en_16k", + "speaker2_a_en_16k", + ] + for filename in filenames: + name = filename.rsplit("_", maxsplit=1)[0] + data, sample_rate = read_wave( + f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav" + ) + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=data) + stream.input_finished() + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + + status = manager.add(name, embedding) + if not status: + raise RuntimeError(f"Failed to register speaker {name}") + + filenames = [ + "speaker1_b_cn_16k", + "speaker1_b_en_16k", + ] + for filename in filenames: + print(filename) + name = filename.rsplit("_", maxsplit=1)[0] + name = name.replace("b_cn", "a_cn") + name = name.replace("b_en", "a_en") + print(name) + + data, sample_rate = read_wave( + f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav" + ) + stream = extractor.create_stream() + stream.accept_waveform(sample_rate=sample_rate, waveform=data) + stream.input_finished() + assert extractor.is_ready(stream) + embedding = extractor.compute(stream) + embedding = np.array(embedding) + status = manager.verify(name, embedding, threshold=0.5) + if not status: + raise RuntimeError( + f"Failed to verify {name} with wave {filename}.wav. model: {model_filename}" + ) + + ans = manager.search(embedding, threshold=0.5) + assert ans == name, (name, ans) + + +class TestSpeakerRecognition(unittest.TestCase): + def test_wespeaker_models(self): + model_dir = Path(d) / "wespeaker" + if not model_dir.is_dir(): + print(f"{model_dir} does not exist - skip it") + return + for filename in model_dir.glob("*.onnx"): + print(filename) + test_wespeaker_model(filename) + + def test_3dpeaker_models(self): + model_dir = Path(d) / "3dspeaker" + if not model_dir.is_dir(): + print(f"{model_dir} does not exist - skip it") + return + for filename in model_dir.glob("*.onnx"): + print(filename) + test_3dspeaker_model(filename) + + +if __name__ == "__main__": + unittest.main() diff --git a/sherpa-onnx/python/tests/test_text2token.py b/sherpa-onnx/python/tests/test_text2token.py old mode 100644 new mode 100755