Add runtime support for wespeaker models (#516)

This commit is contained in:
Fangjun Kuang
2024-01-09 22:06:08 +08:00
committed by GitHub
parent 902b21894b
commit 55266918c8
27 changed files with 1291 additions and 4 deletions

View File

@@ -96,6 +96,14 @@ set(sources
wave-reader.cc
)
# speaker embedding extractor
list(APPEND sources
speaker-embedding-extractor-impl.cc
speaker-embedding-extractor-wespeaker-model.cc
speaker-embedding-extractor.cc
speaker-embedding-manager.cc
)
list(APPEND sources
lexicon.cc
offline-tts-impl.cc
@@ -387,6 +395,10 @@ if(SHERPA_ONNX_ENABLE_TESTS)
utfcpp-test.cc
)
list(APPEND sherpa_onnx_test_srcs
speaker-embedding-manager-test.cc
)
function(sherpa_onnx_add_test source)
get_filename_component(name ${source} NAME_WE)
set(target_name ${name})

View File

@@ -64,8 +64,8 @@ TEST(ContextGraph, Benchmark) {
auto stop = std::chrono::high_resolution_clock::now();
auto duration =
std::chrono::duration_cast<std::chrono::microseconds>(stop - start);
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %ld us.", num,
duration.count());
SHERPA_ONNX_LOGE("Construct context graph for %d item takes %d us.", num,
static_cast<int32_t>(duration.count()));
}
}

View File

@@ -91,4 +91,9 @@ Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
Ort::SessionOptions GetSessionOptions(
const SpeakerEmbeddingExtractorConfig &config) {
return GetSessionOptionsImpl(config.num_threads, config.provider);
}
} // namespace sherpa_onnx

View File

@@ -11,6 +11,7 @@
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
#include "sherpa-onnx/csrc/online-lm-config.h"
#include "sherpa-onnx/csrc/online-model-config.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/csrc/vad-model-config.h"
namespace sherpa_onnx {
@@ -26,6 +27,9 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
Ort::SessionOptions GetSessionOptions(
const SpeakerEmbeddingExtractorConfig &config);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SESSION_H_

View File

@@ -0,0 +1,82 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h"
namespace sherpa_onnx {
namespace {
enum class ModelType {
kWeSpeaker,
kUnkown,
};
} // namespace
static ModelType GetModelType(char *model_data, size_t model_data_length,
bool debug) {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING);
Ort::SessionOptions sess_opts;
auto sess = std::make_unique<Ort::Session>(env, model_data, model_data_length,
sess_opts);
Ort::ModelMetadata meta_data = sess->GetModelMetadata();
if (debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator;
auto model_type =
meta_data.LookupCustomMetadataMapAllocated("framework", allocator);
if (!model_type) {
SHERPA_ONNX_LOGE(
"No model_type in the metadata!\n"
"Please make sure you have added metadata to the model.\n\n"
"For instance, you can use\n"
"https://github.com/k2-fsa/sherpa-onnx/blob/master/scripts/wespeaker/"
"add_meta_data.py"
"to add metadata to models from WeSpeaker\n");
return ModelType::kUnkown;
}
if (model_type.get() == std::string("wespeaker")) {
return ModelType::kWeSpeaker;
} else {
SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
return ModelType::kUnkown;
}
}
std::unique_ptr<SpeakerEmbeddingExtractorImpl>
SpeakerEmbeddingExtractorImpl::Create(
const SpeakerEmbeddingExtractorConfig &config) {
ModelType model_type = ModelType::kUnkown;
{
auto buffer = ReadFile(config.model);
model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
}
switch (model_type) {
case ModelType::kWeSpeaker:
return std::make_unique<SpeakerEmbeddingExtractorWeSpeakerImpl>(config);
case ModelType::kUnkown:
SHERPA_ONNX_LOGE(
"Unknown model type in for speaker embedding extractor!");
return nullptr;
}
// unreachable code
return nullptr;
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,34 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_
#include <memory>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
namespace sherpa_onnx {
class SpeakerEmbeddingExtractorImpl {
public:
virtual ~SpeakerEmbeddingExtractorImpl() = default;
static std::unique_ptr<SpeakerEmbeddingExtractorImpl> Create(
const SpeakerEmbeddingExtractorConfig &config);
virtual int32_t Dim() const = 0;
virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;
virtual bool IsReady(OnlineStream *s) const = 0;
virtual std::vector<float> Compute(OnlineStream *s) const = 0;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_IMPL_H_

View File

@@ -0,0 +1,79 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-impl.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
namespace sherpa_onnx {
class SpeakerEmbeddingExtractorWeSpeakerImpl
: public SpeakerEmbeddingExtractorImpl {
public:
explicit SpeakerEmbeddingExtractorWeSpeakerImpl(
const SpeakerEmbeddingExtractorConfig &config)
: model_(config) {}
int32_t Dim() const override { return model_.GetMetaData().output_dim; }
std::unique_ptr<OnlineStream> CreateStream() const override {
FeatureExtractorConfig feat_config;
auto meta_data = model_.GetMetaData();
feat_config.sampling_rate = meta_data.sample_rate;
feat_config.normalize_samples = meta_data.normalize_features;
return std::make_unique<OnlineStream>(feat_config);
}
bool IsReady(OnlineStream *s) const override {
return s->GetNumProcessedFrames() < s->NumFramesReady();
}
std::vector<float> Compute(OnlineStream *s) const override {
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
if (num_frames <= 0) {
SHERPA_ONNX_LOGE(
"Please make sure IsReady(s) returns true. num_frames: %d",
num_frames);
return {};
}
std::vector<float> features =
s->GetFrames(s->GetNumProcessedFrames(), num_frames);
s->GetNumProcessedFrames() += num_frames;
int32_t feat_dim = features.size() / num_frames;
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
Ort::Value x =
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
x_shape.data(), x_shape.size());
Ort::Value embedding = model_.Compute(std::move(x));
std::vector<int64_t> embedding_shape =
embedding.GetTensorTypeAndShapeInfo().GetShape();
std::vector<float> ans(embedding_shape[1]);
std::copy(embedding.GetTensorData<float>(),
embedding.GetTensorData<float>() + ans.size(), ans.begin());
return ans;
}
private:
SpeakerEmbeddingExtractorWeSpeakerModel model_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_IMPL_H_

View File

@@ -0,0 +1,20 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_
#include <cstdint>
#include <string>
namespace sherpa_onnx {
struct SpeakerEmbeddingExtractorWeSpeakerModelMetaData {
int32_t output_dim = 0;
int32_t sample_rate = 0;
int32_t normalize_features = 0;
std::string language;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_METADATA_H_

View File

@@ -0,0 +1,112 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h"
#include <string>
#include <utility>
#include <vector>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
namespace sherpa_onnx {
class SpeakerEmbeddingExtractorWeSpeakerModel::Impl {
public:
explicit Impl(const SpeakerEmbeddingExtractorConfig &config)
: config_(config),
env_(ORT_LOGGING_LEVEL_ERROR),
sess_opts_(GetSessionOptions(config)),
allocator_{} {
{
auto buf = ReadFile(config.model);
Init(buf.data(), buf.size());
}
}
Ort::Value Compute(Ort::Value x) const {
std::array<Ort::Value, 1> inputs = {std::move(x)};
auto outputs =
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
output_names_ptr_.data(), output_names_ptr_.size());
return std::move(outputs[0]);
}
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const {
return meta_data_;
}
private:
void Init(void *model_data, size_t model_data_length) {
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
sess_opts_);
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
// get meta data
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
if (config_.debug) {
std::ostringstream os;
PrintModelMetadata(os, meta_data);
SHERPA_ONNX_LOGE("%s", os.str().c_str());
}
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
SHERPA_ONNX_READ_META_DATA(meta_data_.output_dim, "output_dim");
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
SHERPA_ONNX_READ_META_DATA(meta_data_.normalize_features,
"normalize_features");
SHERPA_ONNX_READ_META_DATA_STR(meta_data_.language, "language");
std::string framework;
SHERPA_ONNX_READ_META_DATA_STR(framework, "framework");
if (framework != "wespeaker") {
SHERPA_ONNX_LOGE("Expect a wespeaker model, given: %s",
framework.c_str());
exit(-1);
}
}
private:
SpeakerEmbeddingExtractorConfig config_;
Ort::Env env_;
Ort::SessionOptions sess_opts_;
Ort::AllocatorWithDefaultOptions allocator_;
std::unique_ptr<Ort::Session> sess_;
std::vector<std::string> input_names_;
std::vector<const char *> input_names_ptr_;
std::vector<std::string> output_names_;
std::vector<const char *> output_names_ptr_;
SpeakerEmbeddingExtractorWeSpeakerModelMetaData meta_data_;
};
SpeakerEmbeddingExtractorWeSpeakerModel::
SpeakerEmbeddingExtractorWeSpeakerModel(
const SpeakerEmbeddingExtractorConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
SpeakerEmbeddingExtractorWeSpeakerModel::
~SpeakerEmbeddingExtractorWeSpeakerModel() = default;
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &
SpeakerEmbeddingExtractorWeSpeakerModel::GetMetaData() const {
return impl_->GetMetaData();
}
Ort::Value SpeakerEmbeddingExtractorWeSpeakerModel::Compute(
Ort::Value x) const {
return impl_->Compute(std::move(x));
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,37 @@
// sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_
#include <memory>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/speaker-embedding-extractor-wespeaker-model-metadata.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
namespace sherpa_onnx {
class SpeakerEmbeddingExtractorWeSpeakerModel {
public:
explicit SpeakerEmbeddingExtractorWeSpeakerModel(
const SpeakerEmbeddingExtractorConfig &config);
~SpeakerEmbeddingExtractorWeSpeakerModel();
const SpeakerEmbeddingExtractorWeSpeakerModelMetaData &GetMetaData() const;
/**
* @param x A float32 tensor of shape (N, T, C)
* @return A float32 tensor of shape (N, C)
*/
Ort::Value Compute(Ort::Value x) const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_WESPEAKER_MODEL_H_

View File

@@ -0,0 +1,74 @@
// sherpa-onnx/csrc/speaker-embedding-extractor.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
#include <vector>
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
namespace sherpa_onnx {
void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) {
po->Register("model", &model, "Path to the speaker embedding model.");
po->Register("num-threads", &num_threads,
"Number of threads to run the neural network");
po->Register("debug", &debug,
"true to print model information while loading it.");
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
}
bool SpeakerEmbeddingExtractorConfig::Validate() const {
if (model.empty()) {
SHERPA_ONNX_LOGE("Please provide --speaker-embedding-model");
return false;
}
if (!FileExists(model)) {
SHERPA_ONNX_LOGE("--speaker-embedding-model: %s does not exist",
model.c_str());
return false;
}
return true;
}
std::string SpeakerEmbeddingExtractorConfig::ToString() const {
std::ostringstream os;
os << "SpeakerEmbeddingExtractorConfig(";
os << "model=\"" << model << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "debug=" << (debug ? "True" : "False") << ", ";
os << "provider=\"" << provider << "\")";
return os.str();
}
SpeakerEmbeddingExtractor::SpeakerEmbeddingExtractor(
const SpeakerEmbeddingExtractorConfig &config)
: impl_(SpeakerEmbeddingExtractorImpl::Create(config)) {}
SpeakerEmbeddingExtractor::~SpeakerEmbeddingExtractor() = default;
int32_t SpeakerEmbeddingExtractor::Dim() const { return impl_->Dim(); }
std::unique_ptr<OnlineStream> SpeakerEmbeddingExtractor::CreateStream() const {
return impl_->CreateStream();
}
bool SpeakerEmbeddingExtractor::IsReady(OnlineStream *s) const {
return impl_->IsReady(s);
}
std::vector<float> SpeakerEmbeddingExtractor::Compute(OnlineStream *s) const {
return impl_->Compute(s);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,67 @@
// sherpa-onnx/csrc/speaker-embedding-extractor.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
#include <memory>
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/online-stream.h"
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
struct SpeakerEmbeddingExtractorConfig {
std::string model;
int32_t num_threads = 1;
bool debug = false;
std::string provider = "cpu";
SpeakerEmbeddingExtractorConfig() = default;
SpeakerEmbeddingExtractorConfig(const std::string &model, int32_t num_threads,
bool debug, const std::string &provider)
: model(model),
num_threads(num_threads),
debug(debug),
provider(provider) {}
void Register(ParseOptions *po);
bool Validate() const;
std::string ToString() const;
};
class SpeakerEmbeddingExtractorImpl;
class SpeakerEmbeddingExtractor {
public:
explicit SpeakerEmbeddingExtractor(
const SpeakerEmbeddingExtractorConfig &config);
~SpeakerEmbeddingExtractor();
// Return the dimension of the embedding
int32_t Dim() const;
// Create a stream to accept audio samples and compute features
std::unique_ptr<OnlineStream> CreateStream() const;
// Return true if there are feature frames in OnlineStream that
// can be used to compute embeddings.
bool IsReady(OnlineStream *s) const;
// Compute the speaker embedding from the available unprocessed features
// of the given stream
//
// You have to ensure IsReady(s) returns true before you call this method.
std::vector<float> Compute(OnlineStream *s) const;
private:
std::unique_ptr<SpeakerEmbeddingExtractorImpl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_

View File

@@ -0,0 +1,147 @@
// sherpa-onnx/csrc/speaker-embedding-manager-test.cc
//
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include "gtest/gtest.h"
namespace sherpa_onnx {
TEST(SpeakerEmbeddingManager, AddAndRemove) {
int32_t dim = 2;
SpeakerEmbeddingManager manager(dim);
std::vector<float> v = {0.1, 0.1};
bool status = manager.Add("first", v.data());
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 1);
// duplicate
status = manager.Add("first", v.data());
ASSERT_FALSE(status);
ASSERT_EQ(manager.NumSpeakers(), 1);
// non-duplicate
v = {0.1, 0.9};
status = manager.Add("second", v.data());
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 2);
// do not exist
status = manager.Remove("third");
ASSERT_FALSE(status);
status = manager.Remove("first");
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 1);
v = {0.1, 0.1};
status = manager.Add("first", v.data());
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 2);
status = manager.Remove("first");
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 1);
status = manager.Remove("second");
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 0);
}
TEST(SpeakerEmbeddingManager, Search) {
int32_t dim = 2;
SpeakerEmbeddingManager manager(dim);
std::vector<float> v1 = {0.1, 0.1};
std::vector<float> v2 = {0.1, 0.9};
std::vector<float> v3 = {0.9, 0.1};
bool status = manager.Add("first", v1.data());
ASSERT_TRUE(status);
status = manager.Add("second", v2.data());
ASSERT_TRUE(status);
status = manager.Add("third", v3.data());
ASSERT_TRUE(status);
ASSERT_EQ(manager.NumSpeakers(), 3);
std::vector<float> v = {15, 16};
float threshold = 0.9;
std::string name = manager.Search(v.data(), threshold);
EXPECT_EQ(name, "first");
v = {2, 17};
name = manager.Search(v.data(), threshold);
EXPECT_EQ(name, "second");
v = {17, 2};
name = manager.Search(v.data(), threshold);
EXPECT_EQ(name, "third");
threshold = 0.9;
v = {15, 16};
status = manager.Remove("first");
ASSERT_TRUE(status);
name = manager.Search(v.data(), threshold);
EXPECT_EQ(name, "");
v = {17, 2};
status = manager.Remove("third");
ASSERT_TRUE(status);
name = manager.Search(v.data(), threshold);
EXPECT_EQ(name, "");
v = {2, 17};
status = manager.Remove("second");
ASSERT_TRUE(status);
name = manager.Search(v.data(), threshold);
EXPECT_EQ(name, "");
ASSERT_EQ(manager.NumSpeakers(), 0);
}
TEST(SpeakerEmbeddingManager, Verify) {
int32_t dim = 2;
SpeakerEmbeddingManager manager(dim);
std::vector<float> v1 = {0.1, 0.1};
std::vector<float> v2 = {0.1, 0.9};
std::vector<float> v3 = {0.9, 0.1};
bool status = manager.Add("first", v1.data());
ASSERT_TRUE(status);
status = manager.Add("second", v2.data());
ASSERT_TRUE(status);
status = manager.Add("third", v3.data());
ASSERT_TRUE(status);
std::vector<float> v = {15, 16};
float threshold = 0.9;
status = manager.Verify("first", v.data(), threshold);
ASSERT_TRUE(status);
v = {2, 17};
status = manager.Verify("first", v.data(), threshold);
ASSERT_FALSE(status);
status = manager.Verify("second", v.data(), threshold);
ASSERT_TRUE(status);
v = {17, 2};
status = manager.Verify("first", v.data(), threshold);
ASSERT_FALSE(status);
status = manager.Verify("second", v.data(), threshold);
ASSERT_FALSE(status);
status = manager.Verify("third", v.data(), threshold);
ASSERT_TRUE(status);
status = manager.Verify("fourth", v.data(), threshold);
ASSERT_FALSE(status);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,144 @@
// sherpa-onnx/csrc/speaker-embedding-manager.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
#include <algorithm>
#include <unordered_map>
#include "Eigen/Dense"
namespace sherpa_onnx {
using FloatMatrix =
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
class SpeakerEmbeddingManager::Impl {
public:
explicit Impl(int32_t dim) : dim_(dim) {}
bool Add(const std::string &name, const float *p) {
if (name2row_.count(name)) {
// a speaker with the same name already exists
return false;
}
embedding_matrix_.conservativeResize(embedding_matrix_.rows() + 1, dim_);
std::copy(p, p + dim_, &embedding_matrix_.bottomRows(1)(0, 0));
embedding_matrix_.bottomRows(1).normalize(); // inplace
name2row_[name] = embedding_matrix_.rows() - 1;
row2name_[embedding_matrix_.rows() - 1] = name;
return true;
}
bool Remove(const std::string &name) {
if (!name2row_.count(name)) {
return false;
}
int32_t row_idx = name2row_.at(name);
int32_t num_rows = embedding_matrix_.rows();
if (row_idx < num_rows - 1) {
embedding_matrix_.block(row_idx, 0, num_rows - -1 - row_idx, dim_) =
embedding_matrix_.bottomRows(num_rows - 1 - row_idx);
}
embedding_matrix_.conservativeResize(num_rows - 1, dim_);
for (auto &p : name2row_) {
if (p.second > row_idx) {
p.second -= 1;
row2name_[p.second] = p.first;
}
}
name2row_.erase(name);
row2name_.erase(num_rows - 1);
return true;
}
std::string Search(const float *p, float threshold) {
if (embedding_matrix_.rows() == 0) {
return {};
}
Eigen::VectorXf v =
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
v.normalize();
Eigen::VectorXf scores = embedding_matrix_ * v;
Eigen::VectorXf::Index max_index;
float max_score = scores.maxCoeff(&max_index);
if (max_score < threshold) {
return {};
}
return row2name_.at(max_index);
}
bool Verify(const std::string &name, const float *p, float threshold) {
if (!name2row_.count(name)) {
return false;
}
int32_t row_idx = name2row_.at(name);
Eigen::VectorXf v =
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
v.normalize();
float score = embedding_matrix_.row(row_idx) * v;
if (score < threshold) {
return false;
}
return true;
}
int32_t NumSpeakers() const { return embedding_matrix_.rows(); }
private:
int32_t dim_;
FloatMatrix embedding_matrix_;
std::unordered_map<std::string, int32_t> name2row_;
std::unordered_map<int32_t, std::string> row2name_;
};
SpeakerEmbeddingManager::SpeakerEmbeddingManager(int32_t dim)
: impl_(std::make_unique<Impl>(dim)) {}
SpeakerEmbeddingManager::~SpeakerEmbeddingManager() = default;
bool SpeakerEmbeddingManager::Add(const std::string &name,
const float *p) const {
return impl_->Add(name, p);
}
bool SpeakerEmbeddingManager::Remove(const std::string &name) const {
return impl_->Remove(name);
}
std::string SpeakerEmbeddingManager::Search(const float *p,
float threshold) const {
return impl_->Search(p, threshold);
}
bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
float threshold) const {
return impl_->Verify(name, p, threshold);
}
int32_t SpeakerEmbeddingManager::NumSpeakers() const {
return impl_->NumSpeakers();
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,72 @@
// sherpa-onnx/csrc/speaker-embedding-manager.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
#include <memory>
#include <string>
namespace sherpa_onnx {
class SpeakerEmbeddingManager {
public:
// @param dim Embedding dimension.
explicit SpeakerEmbeddingManager(int32_t dim);
~SpeakerEmbeddingManager();
/* Add the embedding and name of a speaker to the manager.
*
* @param name Name of the speaker
* @param p Pointer to the embedding. Its length is `dim`.
* @return Return true if added successfully. Return false if it failed.
* At present, the only reason for a failure is that there is already
* a speaker with the same `name`.
*/
bool Add(const std::string &name, const float *p) const;
/* Remove a speaker by its name.
*
* @param name Name of the speaker to remove.
* @return Return true if it is removed successfully. Return false
* if there is no such a speaker.
*/
bool Remove(const std::string &name) const;
/** It is for speaker identification.
*
* It computes the cosine similarity between and given embedding and all
* other embeddings and find the embedding that has the largest score
* and the score is above or equal to threshold. Return the speaker
* name for the embedding if found; otherwise, it returns an empty string.
*
* @param p The input embedding.
* @param threshold A value between 0 and 1.
* @param If found, return the name of the speaker. Otherwise, return an
* empty string.
*/
std::string Search(const float *p, float threshold) const;
/* Check whether the input embedding matches the embedding of the input
* speaker.
*
* It is for speaker verification.
*
* @param name The target speaker name.
* @param p The input embedding to check.
* @param threshold A value between 0 and 1.
* @return Return true if it matches. Otherwise, it returns false.
*/
bool Verify(const std::string &name, const float *p, float threshold) const;
int32_t NumSpeakers() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_MANAGER_H_

View File

@@ -40,7 +40,9 @@ class VoiceActivityDetector::Impl {
for (int32_t i = 0; i != k; ++i, p += window_size) {
buffer_.Push(p, window_size);
is_speech = is_speech || model_->IsSpeech(p, window_size);
// NOTE(fangjun): Please don't use a very large n.
bool this_window_is_speech = model_->IsSpeech(p, window_size);
is_speech = is_speech || this_window_is_speech;
}
last_ = std::vector<float>(
@@ -102,6 +104,8 @@ class VoiceActivityDetector::Impl {
bool IsSpeechDetected() const { return start_ != -1; }
const VadModelConfig &GetConfig() const { return config_; }
private:
std::queue<SpeechSegment> segments_;
@@ -146,4 +150,8 @@ bool VoiceActivityDetector::IsSpeechDetected() const {
return impl_->IsSpeechDetected();
}
const VadModelConfig &VoiceActivityDetector::GetConfig() const {
return impl_->GetConfig();
}
} // namespace sherpa_onnx

View File

@@ -43,6 +43,8 @@ class VoiceActivityDetector {
void Reset();
const VadModelConfig &GetConfig() const;
private:
class Impl;
std::unique_ptr<Impl> impl_;

View File

@@ -30,6 +30,8 @@ pybind11_add_module(_sherpa_onnx
online-zipformer2-ctc-model-config.cc
sherpa-onnx.cc
silero-vad-model-config.cc
speaker-embedding-extractor.cc
speaker-embedding-manager.cc
vad-model-config.cc
vad-model.cc
voice-activity-detector.cc

View File

@@ -1,4 +1,4 @@
// sherpa-onnx/python/csrc/online-recongizer.h
// sherpa-onnx/python/csrc/online-recognizer.h
//
// Copyright (c) 2023 Xiaomi Corporation

View File

@@ -18,6 +18,8 @@
#include "sherpa-onnx/python/csrc/online-model-config.h"
#include "sherpa-onnx/python/csrc/online-recognizer.h"
#include "sherpa-onnx/python/csrc/online-stream.h"
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
#include "sherpa-onnx/python/csrc/vad-model-config.h"
#include "sherpa-onnx/python/csrc/vad-model.h"
#include "sherpa-onnx/python/csrc/voice-activity-detector.h"
@@ -48,6 +50,8 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindVoiceActivityDetector(&m);
PybindOfflineTts(&m);
PybindSpeakerEmbeddingExtractor(&m);
PybindSpeakerEmbeddingManager(&m);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,44 @@
// sherpa-onnx/python/csrc/speaker-embedding-extractor.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/speaker-embedding-extractor.h"
#include <string>
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
namespace sherpa_onnx {
static void PybindSpeakerEmbeddingExtractorConfig(py::module *m) {
using PyClass = SpeakerEmbeddingExtractorConfig;
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractorConfig")
.def(py::init<>())
.def(py::init<const std::string &, int32_t, bool, const std::string>(),
py::arg("model"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def_readwrite("model", &PyClass::model)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);
}
void PybindSpeakerEmbeddingExtractor(py::module *m) {
PybindSpeakerEmbeddingExtractorConfig(m);
using PyClass = SpeakerEmbeddingExtractor;
py::class_<PyClass>(*m, "SpeakerEmbeddingExtractor")
.def(py::init<const SpeakerEmbeddingExtractorConfig &>(),
py::arg("config"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly("dim", &PyClass::Dim)
.def("create_stream", &PyClass::CreateStream,
py::call_guard<py::gil_scoped_release>())
.def("compute", &PyClass::Compute,
py::call_guard<py::gil_scoped_release>())
.def("is_ready", &PyClass::IsReady,
py::call_guard<py::gil_scoped_release>());
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/speaker-embedding-extractor.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindSpeakerEmbeddingExtractor(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_H_

View File

@@ -0,0 +1,50 @@
// sherpa-onnx/python/csrc/speaker-embedding-manager.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/speaker-embedding-manager.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/speaker-embedding-manager.h"
namespace sherpa_onnx {
void PybindSpeakerEmbeddingManager(py::module *m) {
using PyClass = SpeakerEmbeddingManager;
py::class_<PyClass>(*m, "SpeakerEmbeddingManager")
.def(py::init<int32_t>(), py::arg("dim"),
py::call_guard<py::gil_scoped_release>())
.def_property_readonly("num_speakers", &PyClass::NumSpeakers)
.def(
"add",
[](const PyClass &self, const std::string &name,
const std::vector<float> &v) -> bool {
return self.Add(name, v.data());
},
py::arg("name"), py::arg("v"),
py::call_guard<py::gil_scoped_release>())
.def(
"remove",
[](const PyClass &self, const std::string &name) -> bool {
return self.Remove(name);
},
py::arg("name"), py::call_guard<py::gil_scoped_release>())
.def(
"search",
[](const PyClass &self, const std::vector<float> &v, float threshold)
-> std::string { return self.Search(v.data(), threshold); },
py::arg("v"), py::arg("threshold"),
py::call_guard<py::gil_scoped_release>())
.def(
"verify",
[](const PyClass &self, const std::string &name,
const std::vector<float> &v, float threshold) -> bool {
return self.Verify(name, v.data(), threshold);
},
py::arg("name"), py::arg("v"), py::arg("threshold"),
py::call_guard<py::gil_scoped_release>());
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/speaker-embedding-manager.h
//
// Copyright (c) 2023 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
#define SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindSpeakerEmbeddingManager(py::module *m);
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_PYTHON_CSRC_SPEAKER_EMBEDDING_MANAGER_H_

View File

@@ -32,6 +32,7 @@ void PybindVoiceActivityDetector(py::module *m) {
self.AcceptWaveform(samples.data(), samples.size());
},
py::arg("samples"), py::call_guard<py::gil_scoped_release>())
.def_property_readonly("config", &PyClass::GetConfig)
.def("empty", &PyClass::Empty, py::call_guard<py::gil_scoped_release>())
.def("pop", &PyClass::Pop, py::call_guard<py::gil_scoped_release>())
.def("is_speech_detected", &PyClass::IsSpeechDetected,

View File

@@ -8,6 +8,9 @@ from _sherpa_onnx import (
OfflineTtsVitsModelConfig,
OnlineStream,
SileroVadModelConfig,
SpeakerEmbeddingExtractor,
SpeakerEmbeddingExtractorConfig,
SpeakerEmbeddingManager,
SpeechSegment,
VadModel,
VadModelConfig,