Add C++ runtime for speaker verification models from NeMo (#527)
This commit is contained in:
@@ -57,5 +57,19 @@ done
|
|||||||
ls -lh
|
ls -lh
|
||||||
popd
|
popd
|
||||||
|
|
||||||
|
log "Download NeMo models"

# Fetch the NeMo speaker models into $d/nemo so that the Python speaker
# recognition test run afterwards can find them.
# All expansions are quoted so that paths containing spaces or glob
# characters are not split or expanded by the shell.
model_dir="$d/nemo"
mkdir -p "$model_dir"
pushd "$model_dir"
models=(
  nemo_en_titanet_large.onnx
  nemo_en_titanet_small.onnx
  nemo_en_speakerverification_speakernet.onnx
)
for m in "${models[@]}"; do
  # NOTE(review): "recongition" is spelled exactly as in the original URL;
  # presumably it matches the release tag - verify before "fixing" it.
  wget -q "https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/$m"
done
ls -lh
popd
|
||||||
|
|
||||||
python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose
|
python3 sherpa-onnx/python/tests/test_speaker_recognition.py --verbose
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
function(download_kaldi_native_fbank)
|
function(download_kaldi_native_fbank)
|
||||||
include(FetchContent)
|
include(FetchContent)
|
||||||
|
|
||||||
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.5.tar.gz")
|
set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.6.tar.gz")
|
||||||
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.5.tar.gz")
|
set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.6.tar.gz")
|
||||||
set(kaldi_native_fbank_HASH "SHA256=dce0cb3bc6fece5d8053d8780cb4ce22da57cb57ebec332641661521a0425283")
|
set(kaldi_native_fbank_HASH "SHA256=6202a00cd06ba8ff89beb7b6f85cda34e073e94f25fc29e37c519bff0706bf19")
|
||||||
|
|
||||||
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
|
||||||
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
|
||||||
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
|
|||||||
# If you don't have access to the Internet,
|
# If you don't have access to the Internet,
|
||||||
# please pre-download kaldi-native-fbank
|
# please pre-download kaldi-native-fbank
|
||||||
set(possible_file_locations
|
set(possible_file_locations
|
||||||
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.5.tar.gz
|
$ENV{HOME}/Downloads/kaldi-native-fbank-1.18.6.tar.gz
|
||||||
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.5.tar.gz
|
${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.6.tar.gz
|
||||||
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.5.tar.gz
|
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.6.tar.gz
|
||||||
/tmp/kaldi-native-fbank-1.18.5.tar.gz
|
/tmp/kaldi-native-fbank-1.18.6.tar.gz
|
||||||
/star-fj/fangjun/download/github/kaldi-native-fbank-1.18.5.tar.gz
|
/star-fj/fangjun/download/github/kaldi-native-fbank-1.18.6.tar.gz
|
||||||
)
|
)
|
||||||
|
|
||||||
foreach(f IN LISTS possible_file_locations)
|
foreach(f IN LISTS possible_file_locations)
|
||||||
|
|||||||
@@ -100,6 +100,7 @@ set(sources
|
|||||||
list(APPEND sources
|
list(APPEND sources
|
||||||
speaker-embedding-extractor-impl.cc
|
speaker-embedding-extractor-impl.cc
|
||||||
speaker-embedding-extractor-model.cc
|
speaker-embedding-extractor-model.cc
|
||||||
|
speaker-embedding-extractor-nemo-model.cc
|
||||||
speaker-embedding-extractor.cc
|
speaker-embedding-extractor.cc
|
||||||
speaker-embedding-manager.cc
|
speaker-embedding-manager.cc
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -41,8 +41,12 @@ class FeatureExtractor::Impl {
|
|||||||
public:
|
public:
|
||||||
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
|
explicit Impl(const FeatureExtractorConfig &config) : config_(config) {
|
||||||
opts_.frame_opts.dither = 0;
|
opts_.frame_opts.dither = 0;
|
||||||
opts_.frame_opts.snip_edges = false;
|
opts_.frame_opts.snip_edges = config.snip_edges;
|
||||||
opts_.frame_opts.samp_freq = config.sampling_rate;
|
opts_.frame_opts.samp_freq = config.sampling_rate;
|
||||||
|
opts_.frame_opts.frame_shift_ms = config.frame_shift_ms;
|
||||||
|
opts_.frame_opts.frame_length_ms = config.frame_length_ms;
|
||||||
|
opts_.frame_opts.remove_dc_offset = config.remove_dc_offset;
|
||||||
|
opts_.frame_opts.window_type = config.window_type;
|
||||||
|
|
||||||
opts_.mel_opts.num_bins = config.feature_dim;
|
opts_.mel_opts.num_bins = config.feature_dim;
|
||||||
|
|
||||||
@@ -52,6 +56,9 @@ class FeatureExtractor::Impl {
|
|||||||
// https://github.com/k2-fsa/sherpa-onnx/issues/514
|
// https://github.com/k2-fsa/sherpa-onnx/issues/514
|
||||||
opts_.mel_opts.high_freq = -400;
|
opts_.mel_opts.high_freq = -400;
|
||||||
|
|
||||||
|
opts_.mel_opts.low_freq = config.low_freq;
|
||||||
|
opts_.mel_opts.is_librosa = config.is_librosa;
|
||||||
|
|
||||||
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
fbank_ = std::make_unique<knf::OnlineFbank>(opts_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,14 @@ struct FeatureExtractorConfig {
|
|||||||
// If false, we will multiply the inputs by 32768
|
// If false, we will multiply the inputs by 32768
|
||||||
bool normalize_samples = true;
|
bool normalize_samples = true;
|
||||||
|
|
||||||
|
bool snip_edges = false;
|
||||||
|
float frame_shift_ms = 10.0f; // in milliseconds.
|
||||||
|
float frame_length_ms = 25.0f; // in milliseconds.
|
||||||
|
int32_t low_freq = 20;
|
||||||
|
bool is_librosa = false;
|
||||||
|
bool remove_dc_offset = true; // Subtract mean of wave before FFT.
|
||||||
|
std::string window_type = "povey"; // e.g. Hamming window
|
||||||
|
|
||||||
std::string ToString() const;
|
std::string ToString() const;
|
||||||
|
|
||||||
void Register(ParseOptions *po);
|
void Register(ParseOptions *po);
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
|
// sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
|
||||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
|
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h"
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h"
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -14,6 +15,7 @@ namespace {
|
|||||||
enum class ModelType {
|
enum class ModelType {
|
||||||
kWeSpeaker,
|
kWeSpeaker,
|
||||||
k3dSpeaker,
|
k3dSpeaker,
|
||||||
|
kNeMo,
|
||||||
kUnkown,
|
kUnkown,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -52,6 +54,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
|||||||
return ModelType::kWeSpeaker;
|
return ModelType::kWeSpeaker;
|
||||||
} else if (model_type.get() == std::string("3d-speaker")) {
|
} else if (model_type.get() == std::string("3d-speaker")) {
|
||||||
return ModelType::k3dSpeaker;
|
return ModelType::k3dSpeaker;
|
||||||
|
} else if (model_type.get() == std::string("nemo")) {
|
||||||
|
return ModelType::kNeMo;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||||
return ModelType::kUnkown;
|
return ModelType::kUnkown;
|
||||||
@@ -74,6 +78,8 @@ SpeakerEmbeddingExtractorImpl::Create(
|
|||||||
// fall through
|
// fall through
|
||||||
case ModelType::k3dSpeaker:
|
case ModelType::k3dSpeaker:
|
||||||
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
|
return std::make_unique<SpeakerEmbeddingExtractorGeneralImpl>(config);
|
||||||
|
case ModelType::kNeMo:
|
||||||
|
return std::make_unique<SpeakerEmbeddingExtractorNeMoImpl>(config);
|
||||||
case ModelType::kUnkown:
|
case ModelType::kUnkown:
|
||||||
SHERPA_ONNX_LOGE(
|
SHERPA_ONNX_LOGE(
|
||||||
"Unknown model type in for speaker embedding extractor!");
|
"Unknown model type in for speaker embedding extractor!");
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
|
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
||||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
|
// sherpa-onnx/csrc/speaker-embedding-extractor-model.cc
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/csrc/speaker-embedding-extractor-model.h
|
// sherpa-onnx/csrc/speaker-embedding-extractor-model.h
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023-2024 Xiaomi Corporation
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
|
||||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_MODEL_H_
|
||||||
|
|
||||||
|
|||||||
128
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
Normal file
128
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-impl.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "Eigen/Dense"
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h"
|
||||||
|
#include "sherpa-onnx/csrc/transpose.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl {
|
||||||
|
public:
|
||||||
|
explicit SpeakerEmbeddingExtractorNeMoImpl(
|
||||||
|
const SpeakerEmbeddingExtractorConfig &config)
|
||||||
|
: model_(config) {}
|
||||||
|
|
||||||
|
int32_t Dim() const override { return model_.GetMetaData().output_dim; }
|
||||||
|
|
||||||
|
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||||
|
FeatureExtractorConfig feat_config;
|
||||||
|
const auto &meta_data = model_.GetMetaData();
|
||||||
|
feat_config.sampling_rate = meta_data.sample_rate;
|
||||||
|
feat_config.feature_dim = meta_data.feat_dim;
|
||||||
|
feat_config.normalize_samples = true;
|
||||||
|
feat_config.snip_edges = true;
|
||||||
|
feat_config.frame_shift_ms = meta_data.window_stride_ms;
|
||||||
|
feat_config.frame_length_ms = meta_data.window_size_ms;
|
||||||
|
feat_config.low_freq = 0;
|
||||||
|
feat_config.is_librosa = true;
|
||||||
|
feat_config.remove_dc_offset = false;
|
||||||
|
feat_config.window_type = meta_data.window_type;
|
||||||
|
|
||||||
|
return std::make_unique<OnlineStream>(feat_config);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsReady(OnlineStream *s) const override {
|
||||||
|
return s->GetNumProcessedFrames() < s->NumFramesReady();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> Compute(OnlineStream *s) const override {
|
||||||
|
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
|
||||||
|
if (num_frames <= 0) {
|
||||||
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Please make sure IsReady(s) returns true. num_frames: %d",
|
||||||
|
num_frames);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> features =
|
||||||
|
s->GetFrames(s->GetNumProcessedFrames(), num_frames);
|
||||||
|
|
||||||
|
s->GetNumProcessedFrames() += num_frames;
|
||||||
|
|
||||||
|
int32_t feat_dim = features.size() / num_frames;
|
||||||
|
|
||||||
|
const auto &meta_data = model_.GetMetaData();
|
||||||
|
if (!meta_data.feature_normalize_type.empty()) {
|
||||||
|
if (meta_data.feature_normalize_type == "per_feature") {
|
||||||
|
NormalizePerFeature(features.data(), num_frames, feat_dim);
|
||||||
|
} else {
|
||||||
|
SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s",
|
||||||
|
meta_data.feature_normalize_type.c_str());
|
||||||
|
exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (num_frames % 16 != 0) {
|
||||||
|
int32_t pad = 16 - num_frames % 16;
|
||||||
|
features.resize((num_frames + pad) * feat_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto memory_info =
|
||||||
|
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||||
|
|
||||||
|
std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
|
||||||
|
Ort::Value x =
|
||||||
|
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
|
||||||
|
x_shape.data(), x_shape.size());
|
||||||
|
|
||||||
|
x = Transpose12(model_.Allocator(), &x);
|
||||||
|
|
||||||
|
int64_t x_lens = num_frames;
|
||||||
|
std::array<int64_t, 1> x_lens_shape{1};
|
||||||
|
Ort::Value x_lens_tensor = Ort::Value::CreateTensor(
|
||||||
|
memory_info, &x_lens, 1, x_lens_shape.data(), x_lens_shape.size());
|
||||||
|
|
||||||
|
Ort::Value embedding =
|
||||||
|
model_.Compute(std::move(x), std::move(x_lens_tensor));
|
||||||
|
std::vector<int64_t> embedding_shape =
|
||||||
|
embedding.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
|
std::vector<float> ans(embedding_shape[1]);
|
||||||
|
std::copy(embedding.GetTensorData<float>(),
|
||||||
|
embedding.GetTensorData<float>() + ans.size(), ans.begin());
|
||||||
|
|
||||||
|
return ans;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void NormalizePerFeature(float *p, int32_t num_frames,
|
||||||
|
int32_t feat_dim) const {
|
||||||
|
auto m = Eigen::Map<
|
||||||
|
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
|
||||||
|
p, num_frames, feat_dim);
|
||||||
|
|
||||||
|
auto EX = m.colwise().mean();
|
||||||
|
auto EX2 = m.array().pow(2).colwise().sum() / num_frames;
|
||||||
|
auto variance = EX2 - EX.array().pow(2);
|
||||||
|
auto stddev = variance.array().sqrt();
|
||||||
|
|
||||||
|
m = (m.rowwise() - EX).array().rowwise() / stddev.array();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
SpeakerEmbeddingExtractorNeMoModel model_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_IMPL_H_
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Meta data of a NeMo speaker embedding extractor model.
//
// The values below are only defaults; the model loader fills the fields
// from the meta data stored in the ONNX file.
struct SpeakerEmbeddingExtractorNeMoModelMetaData {
  // Dimension of the embedding produced by the model.
  int32_t output_dim{0};

  // Dimension of the input features.
  int32_t feat_dim{80};

  int32_t sample_rate{0};

  // Analysis window length and shift, in milliseconds.
  // NOTE(review): the default stride equals the default window size; confirm
  // this is intended.
  int32_t window_size_ms{25};
  int32_t window_stride_ms{25};

  // Chinese, English, etc.
  std::string language;

  // How features are normalized before being fed to the model. The NeMo
  // extractor handles "per_feature"; an empty string means no normalization.
  std::string feature_normalize_type;

  std::string window_type{"hann"};
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_META_DATA_H_
|
||||||
126
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc
Normal file
126
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.cc
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/session.h"
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// Owns the ONNX Runtime session of a NeMo speaker embedding model and the
// meta data read from it.
class SpeakerEmbeddingExtractorNeMoModel::Impl {
 public:
  // Reads the model file given by config.model and initializes the session.
  explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    {
      auto buf = ReadFile(config.model);
      Init(buf.data(), buf.size());
    }
  }

  // Runs the model on (x, x_lens) and returns only the embeddings output.
  Ort::Value Compute(Ort::Value x, Ort::Value x_lens) const {
    std::array<Ort::Value, 2> inputs = {std::move(x), std::move(x_lens)};

    // output_names_ptr_[0] is logits
    // output_names_ptr_[1] is embeddings
    // so we use output_names_ptr_.data() + 1 here to extract only the
    // embeddings. Init() has verified that there are at least 2 outputs.
    auto outputs = sess_->Run({}, input_names_ptr_.data(), inputs.data(),
                              inputs.size(), output_names_ptr_.data() + 1, 1);
    return std::move(outputs[0]);
  }

  OrtAllocator *Allocator() const { return allocator_; }

  const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const {
    return meta_data_;
  }

 private:
  // Creates the session from the in-memory model, caches the input/output
  // names and reads the meta data. Exits the process if the model does not
  // have the expected outputs or is not a NeMo model.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // Compute() reads output_names_ptr_[1]; fail early instead of reading
    // past the end of the vector.
    if (output_names_ptr_.size() < 2) {
      SHERPA_ONNX_LOGE(
          "Expect a model with at least 2 outputs (logits, embeddings), "
          "given: %d",
          static_cast<int32_t>(output_names_ptr_.size()));
      exit(-1);
    }

    // get meta data
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
    SHERPA_ONNX_READ_META_DATA(meta_data_.feat_dim, "feat_dim");
    SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
    SHERPA_ONNX_READ_META_DATA(meta_data_.window_size_ms, "window_size_ms");
    SHERPA_ONNX_READ_META_DATA(meta_data_.window_stride_ms, "window_stride_ms");
    SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");

    SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(
        meta_data_.feature_normalize_type, "feature_normalize_type", "");

    // NOTE(review): the fallback here is "povey" while the meta data struct
    // initializes window_type to "hann"; confirm which default is intended.
    SHERPA_ONNX_READ_META_DATA_STR_WITH_DEFAULT(meta_data_.window_type,
                                                "window_type", "povey");

    std::string framework;
    SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
    if (framework != "nemo") {
      SHERPA_ONNX_LOGE("Expect a NeMo model, given: %s", framework.c_str());
      exit(-1);
    }
  }

 private:
  SpeakerEmbeddingExtractorConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  // The *_ptr_ vectors are the C-string views passed to Session::Run();
  // presumably they point into the corresponding std::string vectors.
  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  SpeakerEmbeddingExtractorNeMoModelMetaData meta_data_;
};
|
||||||
|
|
||||||
|
// Constructs the model by creating the private implementation object from
// the given configuration.
SpeakerEmbeddingExtractorNeMoModel::SpeakerEmbeddingExtractorNeMoModel(
|
||||||
|
const SpeakerEmbeddingExtractorConfig &config)
|
||||||
|
: impl_(std::make_unique<Impl>(config)) {}
|
||||||
|
|
||||||
|
// Defaulted here, in the file where Impl is defined, so that the
// std::unique_ptr<Impl> member can destroy a complete type.
SpeakerEmbeddingExtractorNeMoModel::~SpeakerEmbeddingExtractorNeMoModel() =
|
||||||
|
default;
|
||||||
|
|
||||||
|
// Returns a reference to the meta data held by the implementation object.
const SpeakerEmbeddingExtractorNeMoModelMetaData &
|
||||||
|
SpeakerEmbeddingExtractorNeMoModel::GetMetaData() const {
|
||||||
|
return impl_->GetMetaData();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forwards both tensors (moved) to Impl::Compute() and returns its result.
Ort::Value SpeakerEmbeddingExtractorNeMoModel::Compute(
|
||||||
|
Ort::Value x, Ort::Value x_lens) const {
|
||||||
|
return impl_->Compute(std::move(x), std::move(x_lens));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the allocator exposed by the implementation object.
OrtAllocator *SpeakerEmbeddingExtractorNeMoModel::Allocator() const {
|
||||||
|
return impl_->Allocator();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
40
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h
Normal file
40
sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
// sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model.h
|
||||||
|
//
|
||||||
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_
|
||||||
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
|
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor-nemo-model-meta-data.h"
|
||||||
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||||
|
|
||||||
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
|
// A NeMo speaker embedding extractor model.
// All state lives in the private Impl class, which is only forward-declared
// here.
class SpeakerEmbeddingExtractorNeMoModel {
|
||||||
|
public:
|
||||||
|
// Constructs the model from the given extractor configuration.
explicit SpeakerEmbeddingExtractorNeMoModel(
|
||||||
|
const SpeakerEmbeddingExtractorConfig &config);
|
||||||
|
|
||||||
|
~SpeakerEmbeddingExtractorNeMoModel();
|
||||||
|
|
||||||
|
// Returns the meta data of the model.
const SpeakerEmbeddingExtractorNeMoModelMetaData &GetMetaData() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param x A float32 tensor of shape (N, C, T)
|
||||||
|
* @param x_len An int64 tensor of shape (N,)
|
||||||
|
* @return A float32 tensor of shape (N, C)
|
||||||
|
*/
|
||||||
|
Ort::Value Compute(Ort::Value x, Ort::Value x_len) const;
|
||||||
|
|
||||||
|
// Returns the allocator of the model.
OrtAllocator *Allocator() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
class Impl;
|
||||||
|
std::unique_ptr<Impl> impl_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sherpa_onnx
|
||||||
|
|
||||||
|
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_NEMO_MODEL_H_
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/csrc/speaker-embedding-extractor.cc
|
// sherpa-onnx/csrc/speaker-embedding-extractor.cc
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/csrc/speaker-embedding-extractor.h
|
// sherpa-onnx/csrc/speaker-embedding-extractor.h
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/csrc/speaker-embedding-manager-test.cc
|
// sherpa-onnx/csrc/speaker-embedding-manager-test.cc
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
|
// Copyright (c) 2024 Jingzhao Ou (jingzhao.ou@gmail.com)
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/csrc/speaker-embedding-manager.cc
|
// sherpa-onnx/csrc/speaker-embedding-manager.cc
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// sherpa-onnx/csrc/speaker-embedding-manager.h
|
// sherpa-onnx/csrc/speaker-embedding-manager.h
|
||||||
//
|
//
|
||||||
// Copyright (c) 2023 Xiaomi Corporation
|
// Copyright (c) 2024 Xiaomi Corporation
|
||||||
|
|
||||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ def load_speaker_embedding_model(model_filename):
|
|||||||
return extractor
|
return extractor
|
||||||
|
|
||||||
|
|
||||||
def test_wespeaker_model(model_filename: str):
|
def test_zh_models(model_filename: str):
|
||||||
model_filename = str(model_filename)
|
model_filename = str(model_filename)
|
||||||
if "en" in model_filename:
|
if "en" in model_filename:
|
||||||
print(f"skip {model_filename}")
|
print(f"skip {model_filename}")
|
||||||
@@ -114,8 +114,9 @@ def test_wespeaker_model(model_filename: str):
|
|||||||
assert ans == name, (name, ans)
|
assert ans == name, (name, ans)
|
||||||
|
|
||||||
|
|
||||||
def test_3dspeaker_model(model_filename: str):
|
def test_en_and_zh_models(model_filename: str):
|
||||||
extractor = load_speaker_embedding_model(str(model_filename))
|
model_filename = str(model_filename)
|
||||||
|
extractor = load_speaker_embedding_model(model_filename)
|
||||||
manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
|
manager = sherpa_onnx.SpeakerEmbeddingManager(extractor.dim)
|
||||||
|
|
||||||
filenames = [
|
filenames = [
|
||||||
@@ -124,7 +125,14 @@ def test_3dspeaker_model(model_filename: str):
|
|||||||
"speaker1_a_en_16k",
|
"speaker1_a_en_16k",
|
||||||
"speaker2_a_en_16k",
|
"speaker2_a_en_16k",
|
||||||
]
|
]
|
||||||
|
is_en = "en" in model_filename
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
|
if is_en and "cn" in filename:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not is_en and "en" in filename:
|
||||||
|
continue
|
||||||
|
|
||||||
name = filename.rsplit("_", maxsplit=1)[0]
|
name = filename.rsplit("_", maxsplit=1)[0]
|
||||||
data, sample_rate = read_wave(
|
data, sample_rate = read_wave(
|
||||||
f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
|
f"/tmp/sr-models/sr-data/test/3d-speaker/{filename}.wav"
|
||||||
@@ -145,6 +153,11 @@ def test_3dspeaker_model(model_filename: str):
|
|||||||
"speaker1_b_en_16k",
|
"speaker1_b_en_16k",
|
||||||
]
|
]
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
|
if is_en and "cn" in filename:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not is_en and "en" in filename:
|
||||||
|
continue
|
||||||
print(filename)
|
print(filename)
|
||||||
name = filename.rsplit("_", maxsplit=1)[0]
|
name = filename.rsplit("_", maxsplit=1)[0]
|
||||||
name = name.replace("b_cn", "a_cn")
|
name = name.replace("b_cn", "a_cn")
|
||||||
@@ -178,7 +191,8 @@ class TestSpeakerRecognition(unittest.TestCase):
|
|||||||
return
|
return
|
||||||
for filename in model_dir.glob("*.onnx"):
|
for filename in model_dir.glob("*.onnx"):
|
||||||
print(filename)
|
print(filename)
|
||||||
test_wespeaker_model(filename)
|
test_zh_models(filename)
|
||||||
|
test_en_and_zh_models(filename)
|
||||||
|
|
||||||
def test_3dpeaker_models(self):
|
def test_3dpeaker_models(self):
|
||||||
model_dir = Path(d) / "3dspeaker"
|
model_dir = Path(d) / "3dspeaker"
|
||||||
@@ -187,7 +201,16 @@ class TestSpeakerRecognition(unittest.TestCase):
|
|||||||
return
|
return
|
||||||
for filename in model_dir.glob("*.onnx"):
|
for filename in model_dir.glob("*.onnx"):
|
||||||
print(filename)
|
print(filename)
|
||||||
test_3dspeaker_model(filename)
|
test_en_and_zh_models(filename)
|
||||||
|
|
||||||
|
def test_nemo_models(self):
    """Run the English/Chinese speaker test on every ONNX model in ``d/nemo``.

    If the directory does not exist, a message is printed and nothing is
    tested.
    """
    model_dir = Path(d) / "nemo"
    if model_dir.is_dir():
        for onnx_file in model_dir.glob("*.onnx"):
            print(onnx_file)
            test_en_and_zh_models(onnx_file)
    else:
        print(f"{model_dir} does not exist - skip it")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user