diff --git a/sherpa-onnx/csrc/speaker-embedding-manager.cc b/sherpa-onnx/csrc/speaker-embedding-manager.cc index e067a2eb..f1c5251d 100644 --- a/sherpa-onnx/csrc/speaker-embedding-manager.cc +++ b/sherpa-onnx/csrc/speaker-embedding-manager.cc @@ -151,6 +151,23 @@ class SpeakerEmbeddingManager::Impl { return true; } + float Score(const std::string &name, const float *p) { + if (!name2row_.count(name)) { + // Setting a default value if the name is not found + return -2.0; + } + + int32_t row_idx = name2row_.at(name); + + Eigen::VectorXf v = + Eigen::Map(const_cast(p), dim_); + v.normalize(); + + float score = embedding_matrix_.row(row_idx) * v; + + return score; + } + bool Contains(const std::string &name) const { return name2row_.count(name) > 0; } @@ -206,6 +223,11 @@ bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p, return impl_->Verify(name, p, threshold); } +float SpeakerEmbeddingManager::Score(const std::string &name, + const float *p) const { + return impl_->Score(name, p); +} + int32_t SpeakerEmbeddingManager::NumSpeakers() const { return impl_->NumSpeakers(); } diff --git a/sherpa-onnx/csrc/speaker-embedding-manager.h b/sherpa-onnx/csrc/speaker-embedding-manager.h index c1af12fc..ae8728b1 100644 --- a/sherpa-onnx/csrc/speaker-embedding-manager.h +++ b/sherpa-onnx/csrc/speaker-embedding-manager.h @@ -74,6 +74,8 @@ class SpeakerEmbeddingManager { */ bool Verify(const std::string &name, const float *p, float threshold) const; + float Score(const std::string &name, const float *p) const; + // Return true if the given speaker already exists; return false otherwise. bool Contains(const std::string &name) const; diff --git a/sherpa-onnx/python/csrc/speaker-embedding-manager.cc b/sherpa-onnx/python/csrc/speaker-embedding-manager.cc index b1580ec1..b7bc4e17 100644 --- a/sherpa-onnx/python/csrc/speaker-embedding-manager.cc +++ b/sherpa-onnx/python/csrc/speaker-embedding-manager.cc @@ -60,6 +60,14 @@ void PybindSpeakerEmbeddingManager(py::module *m) { return self.Verify(name, v.data(), threshold); }, py::arg("name"), py::arg("v"), py::arg("threshold"), + py::call_guard()) + .def( + "score", + [](const PyClass &self, const std::string &name, + const std::vector &v) -> float { + return self.Score(name, v.data()); + }, + py::arg("name"), py::arg("v"), py::call_guard()); }