Add runtime support for wespeaker models (#516)
This commit is contained in:
@@ -96,6 +96,14 @@ set(sources
|
||||
wave-reader.cc
|
||||
)
|
||||
|
||||
# speaker embedding extractor
|
||||
list(APPEND sources
|
||||
speaker-embedding-extractor-impl.cc
|
||||
speaker-embedding-extractor-wespeaker-model.cc
|
||||
speaker-embedding-extractor.cc
|
||||
speaker-embedding-manager.cc
|
||||
)
|
||||
|
||||
list(APPEND sources
|
||||
lexicon.cc
|
||||
offline-tts-impl.cc
|
||||
@@ -387,6 +395,10 @@ if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
utfcpp-test.cc
|
||||
)
|
||||
|
||||
list(APPEND sherpa_onnx_test_srcs
|
||||
speaker-embedding-manager-test.cc
|
||||
)
|
||||
|
||||
function(sherpa_onnx_add_test source)
|
||||
get_filename_component(name ${source} NAME_WE)
|
||||
set(target_name ${name})
|
||||
|
||||
@@ -64,8 +64,8 @@ TEST(ContextGraph, Benchmark) {
|
||||
auto stop = std::chrono::high_resolution_clock::now();
|
||||
auto duration =
|
||||
std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
|
||||
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num,
|
||||
duration.count());
|
||||
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %d us.", num,
|
||||
static_cast<int32_t>(duration.count()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -91,4 +91,9 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpeakerEmbeddingExtractorConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
#include "sherpa-onnx/csrc/vad-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -26,6 +27,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
|
||||
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpeakerEmbeddingExtractorConfig &config);
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
||||
|
||||
82
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
Normal file
82
sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
Normal file
@@ -0,0 +1,82 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
namespace {
|
||||
|
||||
enum class ModelType {
|
||||
kWeSpeaker,
|
||||
kUnkown,
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
bool debug) {
|
||||
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
|
||||
Ort::SessionOptions sess_opts;
|
||||
|
||||
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
|
||||
sess_opts);
|
||||
|
||||
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
|
||||
if (debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator;
|
||||
auto model_type =
|
||||
meta_data.LookupCustomMetadataMapAllocated("framework", allocator);
|
||||
if (!model_type) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"No model_type in the metadata!\n"
|
||||
"Please make sure you have added metadata to the model.\n\n"
|
||||
"For instance, you can use\n"
|
||||
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
|
||||
"add_meta_data.py"
|
||||
"to add metadata to models from WeSpeaker\n");
|
||||
return ModelType::kUnkown;
|
||||
}
|
||||
|
||||
if (model_type.get() == std::string("wespeaker")) {
|
||||
return ModelType::kWeSpeaker;
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
|
||||
return ModelType::kUnkown;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
|
||||
SpeakerEmbeddingExtractorImpl::Create(
|
||||
const SpeakerEmbeddingExtractorConfig &config) {
|
||||
ModelType model_type = ModelType::kUnkown;
|
||||
|
||||
{
|
||||
auto buffer = ReadFile(config.model);
|
||||
|
||||
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||
}
|
||||
|
||||
switch (model_type) {
|
||||
case ModelType::kWeSpeaker:
|
||||
return std::make_unique<SpeakerEmbeddingExtractorWeSpeakerImpl>(config);
|
||||
case ModelType::kUnkown:
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Unknown model type in for speaker embedding extractor!");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// unreachable code
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
34
sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
Normal file
34
sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
Normal file
@@ -0,0 +1,34 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpeakerEmbeddingExtractorImpl {
|
||||
public:
|
||||
virtual ~SpeakerEmbeddingExtractorImpl() = default;
|
||||
|
||||
static std::unique_ptr<SpeakerEmbeddingExtractorImpl> Create(
|
||||
const SpeakerEmbeddingExtractorConfig &config);
|
||||
|
||||
virtual int32_t Dim() const = 0;
|
||||
|
||||
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
|
||||
|
||||
virtual bool IsReady(OnlineStream *s) const = 0;
|
||||
|
||||
virtual std::vector<float> Compute(OnlineStream *s) const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
|
||||
@@ -0,0 +1,79 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpeakerEmbeddingExtractorWeSpeakerImpl
|
||||
: public SpeakerEmbeddingExtractorImpl {
|
||||
public:
|
||||
explicit SpeakerEmbeddingExtractorWeSpeakerImpl(
|
||||
const SpeakerEmbeddingExtractorConfig &config)
|
||||
: model_(config) {}
|
||||
|
||||
int32_t Dim() const override { return model_.GetMetaData().output_dim; }
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||
FeatureExtractorConfig feat_config;
|
||||
auto meta_data = model_.GetMetaData();
|
||||
feat_config.sampling_rate = meta_data.sample_rate;
|
||||
feat_config.normalize_samples = meta_data.normalize_features;
|
||||
|
||||
return std::make_unique<OnlineStream>(feat_config);
|
||||
}
|
||||
|
||||
bool IsReady(OnlineStream *s) const override {
|
||||
return s->GetNumProcessedFrames() < s->NumFramesReady();
|
||||
}
|
||||
|
||||
std::vector<float> Compute(OnlineStream *s) const override {
|
||||
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
|
||||
if (num_frames <= 0) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Please make sure IsReady(s) returns true. num_frames: %d",
|
||||
num_frames);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<float> features =
|
||||
s->GetFrames(s->GetNumProcessedFrames(), num_frames);
|
||||
|
||||
s->GetNumProcessedFrames() += num_frames;
|
||||
|
||||
int32_t feat_dim = features.size() / num_frames;
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
|
||||
Ort::Value x =
|
||||
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
|
||||
x_shape.data(), x_shape.size());
|
||||
Ort::Value embedding = model_.Compute(std::move(x));
|
||||
std::vector<int64_t> embedding_shape =
|
||||
embedding.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
std::vector<float> ans(embedding_shape[1]);
|
||||
std::copy(embedding.GetTensorData<float>(),
|
||||
embedding.GetTensorData<float>() + ans.size(), ans.begin());
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
SpeakerEmbeddingExtractorWeSpeakerModel model_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
|
||||
@@ -0,0 +1,20 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData {
|
||||
int32_t output_dim = 0;
|
||||
int32_t sample_rate = 0;
|
||||
int32_t normalize_features = 0;
|
||||
std::string language;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
|
||||
112
sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
Normal file
112
sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
Normal file
@@ -0,0 +1,112 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
|
||||
public:
|
||||
explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
{
|
||||
auto buf = ReadFile(config.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
|
||||
Ort::Value Compute(Ort::Value x) const {
|
||||
std::array<Ort::Value, 1> inputs = {std::move(x)};
|
||||
|
||||
auto outputs =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
return std::move(outputs[0]);
|
||||
}
|
||||
|
||||
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const {
|
||||
return meta_data_;
|
||||
}
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_features,
|
||||
"normalize_features");
|
||||
SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
|
||||
|
||||
std::string framework;
|
||||
SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
|
||||
if (framework != "wespeaker") {
|
||||
SHERPA_ONNX_LOGE("Expect a wespeaker model, given: %s",
|
||||
framework.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
SpeakerEmbeddingExtractorConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
|
||||
SpeakerEmbeddingExtractorWeSpeakerModelMetaData meta_data_;
|
||||
};
|
||||
|
||||
SpeakerEmbeddingExtractorWeSpeakerModel::
|
||||
SpeakerEmbeddingExtractorWeSpeakerModel(
|
||||
const SpeakerEmbeddingExtractorConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
SpeakerEmbeddingExtractorWeSpeakerModel::
|
||||
~SpeakerEmbeddingExtractorWeSpeakerModel() = default;
|
||||
|
||||
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &
|
||||
SpeakerEmbeddingExtractorWeSpeakerModel::GetMetaData() const {
|
||||
return impl_->GetMetaData();
|
||||
}
|
||||
|
||||
Ort::Value SpeakerEmbeddingExtractorWeSpeakerModel::Compute(
|
||||
Ort::Value x) const {
|
||||
return impl_->Compute(std::move(x));
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,37 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpeakerEmbeddingExtractorWeSpeakerModel {
|
||||
public:
|
||||
explicit SpeakerEmbeddingExtractorWeSpeakerModel(
|
||||
const SpeakerEmbeddingExtractorConfig &config);
|
||||
|
||||
~SpeakerEmbeddingExtractorWeSpeakerModel();
|
||||
|
||||
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const;
|
||||
|
||||
/**
|
||||
* @param x A float32 tensor of shape (N, T, C)
|
||||
* @return A float32 tensor of shape (N, C)
|
||||
*/
|
||||
Ort::Value Compute(Ort::Value x) const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
|
||||
74
sherpa-onnx/csrc/speaker-embedding-extractor.cc
Normal file
74
sherpa-onnx/csrc/speaker-embedding-extractor.cc
Normal file
@@ -0,0 +1,74 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) {
|
||||
po->Register("model", &model, "Path to the speaker embedding model.");
|
||||
po->Register("num-threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
po->Register("debug", &debug,
|
||||
"true to print model information while loading it.");
|
||||
|
||||
po->Register("provider", &provider,
|
||||
"Specify a provider to use: cpu, cuda, coreml");
|
||||
}
|
||||
|
||||
bool SpeakerEmbeddingExtractorConfig::Validate() const {
|
||||
if (model.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --speaker-embedding-model");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("--speaker-embedding-model: %s does not exist",
|
||||
model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string SpeakerEmbeddingExtractorConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "SpeakerEmbeddingExtractorConfig(";
|
||||
os << "model=\"" << model << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor(
|
||||
const SpeakerEmbeddingExtractorConfig &config)
|
||||
: impl_(SpeakerEmbeddingExtractorImpl::Create(config)) {}
|
||||
|
||||
SpeakerEmbeddingExtractor::~SpeakerEmbeddingExtractor() = default;
|
||||
|
||||
int32_t SpeakerEmbeddingExtractor::Dim() const { return impl_->Dim(); }
|
||||
|
||||
std::unique_ptr<OnlineStream> SpeakerEmbeddingExtractor::CreateStream() const {
|
||||
return impl_->CreateStream();
|
||||
}
|
||||
|
||||
bool SpeakerEmbeddingExtractor::IsReady(OnlineStream *s) const {
|
||||
return impl_->IsReady(s);
|
||||
}
|
||||
|
||||
std::vector<float> SpeakerEmbeddingExtractor::Compute(OnlineStream *s) const {
|
||||
return impl_->Compute(s);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
67
sherpa-onnx/csrc/speaker-embedding-extractor.h
Normal file
67
sherpa-onnx/csrc/speaker-embedding-extractor.h
Normal file
@@ -0,0 +1,67 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/online-stream.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct SpeakerEmbeddingExtractorConfig {
|
||||
std::string model;
|
||||
int32_t num_threads = 1;
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
SpeakerEmbeddingExtractorConfig() = default;
|
||||
SpeakerEmbeddingExtractorConfig(const std::string &model, int32_t num_threads,
|
||||
bool debug, const std::string &provider)
|
||||
: model(model),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class SpeakerEmbeddingExtractorImpl;
|
||||
|
||||
class SpeakerEmbeddingExtractor {
|
||||
public:
|
||||
explicit SpeakerEmbeddingExtractor(
|
||||
const SpeakerEmbeddingExtractorConfig &config);
|
||||
|
||||
~SpeakerEmbeddingExtractor();
|
||||
|
||||
// Return the dimension of the embedding
|
||||
int32_t Dim() const;
|
||||
|
||||
// Create a stream to accept audio samples and compute features
|
||||
std::unique_ptr<OnlineStream> CreateStream() const;
|
||||
|
||||
// Return true if there are feature frames in OnlineStream that
|
||||
// can be used to compute embeddings.
|
||||
bool IsReady(OnlineStream *s) const;
|
||||
|
||||
// Compute the speaker embedding from the available unprocessed features
|
||||
// of the given stream
|
||||
//
|
||||
// You have to ensure IsReady(s) returns true before you call this method.
|
||||
std::vector<float> Compute(OnlineStream *s) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<SpeakerEmbeddingExtractorImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
|
||||
147
sherpa-onnx/csrc/speaker-embedding-manager-test.cc
Normal file
147
sherpa-onnx/csrc/speaker-embedding-manager-test.cc
Normal file
@@ -0,0 +1,147 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-manager-test.cc
|
||||
//
|
||||
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
TEST(SpeakerEmbeddingManager, AddAndRemove) {
|
||||
int32_t dim = 2;
|
||||
SpeakerEmbeddingManager manager(dim);
|
||||
std::vector<float> v = {0.1, 0.1};
|
||||
bool status = manager.Add("first", v.data());
|
||||
ASSERT_TRUE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 1);
|
||||
|
||||
// duplicate
|
||||
status = manager.Add("first", v.data());
|
||||
ASSERT_FALSE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 1);
|
||||
|
||||
// non-duplicate
|
||||
v = {0.1, 0.9};
|
||||
status = manager.Add("second", v.data());
|
||||
ASSERT_TRUE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 2);
|
||||
|
||||
// do not exist
|
||||
status = manager.Remove("third");
|
||||
ASSERT_FALSE(status);
|
||||
|
||||
status = manager.Remove("first");
|
||||
ASSERT_TRUE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 1);
|
||||
|
||||
v = {0.1, 0.1};
|
||||
status = manager.Add("first", v.data());
|
||||
ASSERT_TRUE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 2);
|
||||
|
||||
status = manager.Remove("first");
|
||||
ASSERT_TRUE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 1);
|
||||
|
||||
status = manager.Remove("second");
|
||||
ASSERT_TRUE(status);
|
||||
ASSERT_EQ(manager.NumSpeakers(), 0);
|
||||
}
|
||||
|
||||
TEST(SpeakerEmbeddingManager, Search) {
|
||||
int32_t dim = 2;
|
||||
SpeakerEmbeddingManager manager(dim);
|
||||
std::vector<float> v1 = {0.1, 0.1};
|
||||
std::vector<float> v2 = {0.1, 0.9};
|
||||
std::vector<float> v3 = {0.9, 0.1};
|
||||
bool status = manager.Add("first", v1.data());
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
status = manager.Add("second", v2.data());
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
status = manager.Add("third", v3.data());
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
ASSERT_EQ(manager.NumSpeakers(), 3);
|
||||
|
||||
std::vector<float> v = {15, 16};
|
||||
float threshold = 0.9;
|
||||
|
||||
std::string name = manager.Search(v.data(), threshold);
|
||||
EXPECT_EQ(name, "first");
|
||||
|
||||
v = {2, 17};
|
||||
name = manager.Search(v.data(), threshold);
|
||||
EXPECT_EQ(name, "second");
|
||||
|
||||
v = {17, 2};
|
||||
name = manager.Search(v.data(), threshold);
|
||||
EXPECT_EQ(name, "third");
|
||||
|
||||
threshold = 0.9;
|
||||
v = {15, 16};
|
||||
status = manager.Remove("first");
|
||||
ASSERT_TRUE(status);
|
||||
name = manager.Search(v.data(), threshold);
|
||||
EXPECT_EQ(name, "");
|
||||
|
||||
v = {17, 2};
|
||||
status = manager.Remove("third");
|
||||
ASSERT_TRUE(status);
|
||||
name = manager.Search(v.data(), threshold);
|
||||
EXPECT_EQ(name, "");
|
||||
|
||||
v = {2, 17};
|
||||
status = manager.Remove("second");
|
||||
ASSERT_TRUE(status);
|
||||
name = manager.Search(v.data(), threshold);
|
||||
EXPECT_EQ(name, "");
|
||||
|
||||
ASSERT_EQ(manager.NumSpeakers(), 0);
|
||||
}
|
||||
|
||||
TEST(SpeakerEmbeddingManager, Verify) {
|
||||
int32_t dim = 2;
|
||||
SpeakerEmbeddingManager manager(dim);
|
||||
std::vector<float> v1 = {0.1, 0.1};
|
||||
std::vector<float> v2 = {0.1, 0.9};
|
||||
std::vector<float> v3 = {0.9, 0.1};
|
||||
bool status = manager.Add("first", v1.data());
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
status = manager.Add("second", v2.data());
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
status = manager.Add("third", v3.data());
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
std::vector<float> v = {15, 16};
|
||||
float threshold = 0.9;
|
||||
|
||||
status = manager.Verify("first", v.data(), threshold);
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
v = {2, 17};
|
||||
status = manager.Verify("first", v.data(), threshold);
|
||||
ASSERT_FALSE(status);
|
||||
|
||||
status = manager.Verify("second", v.data(), threshold);
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
v = {17, 2};
|
||||
status = manager.Verify("first", v.data(), threshold);
|
||||
ASSERT_FALSE(status);
|
||||
|
||||
status = manager.Verify("second", v.data(), threshold);
|
||||
ASSERT_FALSE(status);
|
||||
|
||||
status = manager.Verify("third", v.data(), threshold);
|
||||
ASSERT_TRUE(status);
|
||||
|
||||
status = manager.Verify("fourth", v.data(), threshold);
|
||||
ASSERT_FALSE(status);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
144
sherpa-onnx/csrc/speaker-embedding-manager.cc
Normal file
144
sherpa-onnx/csrc/speaker-embedding-manager.cc
Normal file
@@ -0,0 +1,144 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-manager.cc
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "Eigen/Dense"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
using FloatMatrix =
|
||||
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
|
||||
class SpeakerEmbeddingManager::Impl {
|
||||
public:
|
||||
explicit Impl(int32_t dim) : dim_(dim) {}
|
||||
|
||||
bool Add(const std::string &name, const float *p) {
|
||||
if (name2row_.count(name)) {
|
||||
// a speaker with the same name already exists
|
||||
return false;
|
||||
}
|
||||
|
||||
embedding_matrix_.conservativeResize(embedding_matrix_.rows() + 1, dim_);
|
||||
|
||||
std::copy(p, p + dim_, &embedding_matrix_.bottomRows(1)(0, 0));
|
||||
|
||||
embedding_matrix_.bottomRows(1).normalize(); // inplace
|
||||
|
||||
name2row_[name] = embedding_matrix_.rows() - 1;
|
||||
row2name_[embedding_matrix_.rows() - 1] = name;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Remove(const std::string &name) {
|
||||
if (!name2row_.count(name)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t row_idx = name2row_.at(name);
|
||||
|
||||
int32_t num_rows = embedding_matrix_.rows();
|
||||
|
||||
if (row_idx < num_rows - 1) {
|
||||
embedding_matrix_.block(row_idx, 0, num_rows - -1 - row_idx, dim_) =
|
||||
embedding_matrix_.bottomRows(num_rows - 1 - row_idx);
|
||||
}
|
||||
|
||||
embedding_matrix_.conservativeResize(num_rows - 1, dim_);
|
||||
for (auto &p : name2row_) {
|
||||
if (p.second > row_idx) {
|
||||
p.second -= 1;
|
||||
row2name_[p.second] = p.first;
|
||||
}
|
||||
}
|
||||
|
||||
name2row_.erase(name);
|
||||
row2name_.erase(num_rows - 1);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string Search(const float *p, float threshold) {
|
||||
if (embedding_matrix_.rows() == 0) {
|
||||
return {};
|
||||
}
|
||||
|
||||
Eigen::VectorXf v =
|
||||
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
|
||||
v.normalize();
|
||||
|
||||
Eigen::VectorXf scores = embedding_matrix_ * v;
|
||||
|
||||
Eigen::VectorXf::Index max_index;
|
||||
float max_score = scores.maxCoeff(&max_index);
|
||||
if (max_score < threshold) {
|
||||
return {};
|
||||
}
|
||||
|
||||
return row2name_.at(max_index);
|
||||
}
|
||||
|
||||
bool Verify(const std::string &name, const float *p, float threshold) {
|
||||
if (!name2row_.count(name)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int32_t row_idx = name2row_.at(name);
|
||||
|
||||
Eigen::VectorXf v =
|
||||
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
|
||||
v.normalize();
|
||||
|
||||
float score = embedding_matrix_.row(row_idx) * v;
|
||||
|
||||
if (score < threshold) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int32_t NumSpeakers() const { return embedding_matrix_.rows(); }
|
||||
|
||||
private:
|
||||
int32_t dim_;
|
||||
FloatMatrix embedding_matrix_;
|
||||
std::unordered_map<std::string, int32_t> name2row_;
|
||||
std::unordered_map<int32_t, std::string> row2name_;
|
||||
};
|
||||
|
||||
SpeakerEmbeddingManager::SpeakerEmbeddingManager(int32_t dim)
|
||||
: impl_(std::make_unique<Impl>(dim)) {}
|
||||
|
||||
SpeakerEmbeddingManager::~SpeakerEmbeddingManager() = default;
|
||||
|
||||
bool SpeakerEmbeddingManager::Add(const std::string &name,
|
||||
const float *p) const {
|
||||
return impl_->Add(name, p);
|
||||
}
|
||||
|
||||
bool SpeakerEmbeddingManager::Remove(const std::string &name) const {
|
||||
return impl_->Remove(name);
|
||||
}
|
||||
|
||||
std::string SpeakerEmbeddingManager::Search(const float *p,
|
||||
float threshold) const {
|
||||
return impl_->Search(p, threshold);
|
||||
}
|
||||
|
||||
bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
|
||||
float threshold) const {
|
||||
return impl_->Verify(name, p, threshold);
|
||||
}
|
||||
|
||||
int32_t SpeakerEmbeddingManager::NumSpeakers() const {
|
||||
return impl_->NumSpeakers();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
72
sherpa-onnx/csrc/speaker-embedding-manager.h
Normal file
72
sherpa-onnx/csrc/speaker-embedding-manager.h
Normal file
@@ -0,0 +1,72 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-manager.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpeakerEmbeddingManager {
|
||||
public:
|
||||
// @param dim Embedding dimension.
|
||||
explicit SpeakerEmbeddingManager(int32_t dim);
|
||||
~SpeakerEmbeddingManager();
|
||||
|
||||
/* Add the embedding and name of a speaker to the manager.
|
||||
*
|
||||
* @param name Name of the speaker
|
||||
* @param p Pointer to the embedding. Its length is `dim`.
|
||||
* @return Return true if added successfully. Return false if it failed.
|
||||
* At present, the only reason for a failure is that there is already
|
||||
* a speaker with the same `name`.
|
||||
*/
|
||||
bool Add(const std::string &name, const float *p) const;
|
||||
|
||||
/* Remove a speaker by its name.
|
||||
*
|
||||
* @param name Name of the speaker to remove.
|
||||
* @return Return true if it is removed successfully. Return false
|
||||
* if there is no such a speaker.
|
||||
*/
|
||||
bool Remove(const std::string &name) const;
|
||||
|
||||
/** It is for speaker identification.
|
||||
*
|
||||
* It computes the cosine similarity between and given embedding and all
|
||||
* other embeddings and find the embedding that has the largest score
|
||||
* and the score is above or equal to threshold. Return the speaker
|
||||
* name for the embedding if found; otherwise, it returns an empty string.
|
||||
*
|
||||
* @param p The input embedding.
|
||||
* @param threshold A value between 0 and 1.
|
||||
* @param If found, return the name of the speaker. Otherwise, return an
|
||||
* empty string.
|
||||
*/
|
||||
std::string Search(const float *p, float threshold) const;
|
||||
|
||||
/* Check whether the input embedding matches the embedding of the input
|
||||
* speaker.
|
||||
*
|
||||
* It is for speaker verification.
|
||||
*
|
||||
* @param name The target speaker name.
|
||||
* @param p The input embedding to check.
|
||||
* @param threshold A value between 0 and 1.
|
||||
* @return Return true if it matches. Otherwise, it returns false.
|
||||
*/
|
||||
bool Verify(const std::string &name, const float *p, float threshold) const;
|
||||
|
||||
int32_t NumSpeakers() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
|
||||
@@ -40,7 +40,9 @@ class VoiceActivityDetector::Impl {
|
||||
|
||||
for (int32_t i = 0; i != k; ++i, p += window_size) {
|
||||
buffer_.Push(p, window_size);
|
||||
is_speech = is_speech || model_->IsSpeech(p, window_size);
|
||||
// NOTE(fangjun): Please don't use a very large n.
|
||||
bool this_window_is_speech = model_->IsSpeech(p, window_size);
|
||||
is_speech = is_speech || this_window_is_speech;
|
||||
}
|
||||
|
||||
last_ = std::vector<float>(
|
||||
@@ -102,6 +104,8 @@ class VoiceActivityDetector::Impl {
|
||||
|
||||
bool IsSpeechDetected() const { return start_ != -1; }
|
||||
|
||||
const VadModelConfig &GetConfig() const { return config_; }
|
||||
|
||||
private:
|
||||
std::queue<SpeechSegment> segments_;
|
||||
|
||||
@@ -146,4 +150,8 @@ bool VoiceActivityDetector::IsSpeechDetected() const {
|
||||
return impl_->IsSpeechDetected();
|
||||
}
|
||||
|
||||
const VadModelConfig &VoiceActivityDetector::GetConfig() const {
|
||||
return impl_->GetConfig();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -43,6 +43,8 @@ class VoiceActivityDetector {
|
||||
|
||||
void Reset();
|
||||
|
||||
const VadModelConfig &GetConfig() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
|
||||
Reference in New Issue
Block a user