feat: find best embedding matches (#1102)

This commit is contained in:
thewh1teagle
2024-07-11 04:38:06 +03:00
committed by GitHub
parent 1c104ea847
commit c0eaf86dbd
4 changed files with 134 additions and 0 deletions

View File

@@ -131,6 +131,40 @@ class SpeakerEmbeddingManager::Impl {
return row2name_.at(max_index);
}
std::vector<SpeakerMatch> GetBestMatches(const float *p, float threshold,
int32_t n) {
std::vector<SpeakerMatch> matches;
if (embedding_matrix_.rows() == 0) {
return matches;
}
Eigen::VectorXf v =
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
v.normalize();
Eigen::VectorXf scores = embedding_matrix_ * v;
std::vector<std::pair<float, int>> score_indices;
for (int i = 0; i < scores.size(); ++i) {
if (scores[i] >= threshold) {
score_indices.emplace_back(scores[i], i);
}
}
std::sort(score_indices.rbegin(), score_indices.rend(),
[](const auto &a, const auto &b) { return a.first < b.first; });
matches.reserve(score_indices.size());
for (int i = 0; i < std::min(n, static_cast<int32_t>(score_indices.size()));
++i) {
const auto &pair = score_indices[i];
matches.push_back({row2name_.at(pair.second), pair.first});
}
return matches;
}
bool Verify(const std::string &name, const float *p, float threshold) {
if (!name2row_.count(name)) {
return false;
@@ -219,6 +253,11 @@ std::string SpeakerEmbeddingManager::Search(const float *p,
return impl_->Search(p, threshold);
}
std::vector<SpeakerMatch> SpeakerEmbeddingManager::GetBestMatches(
const float *p, float threshold, int32_t n) const {
return impl_->GetBestMatches(p, threshold, n);
}
bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
float threshold) const {
return impl_->Verify(name, p, threshold);

View File

@@ -9,6 +9,11 @@
#include <string>
#include <vector>
struct SpeakerMatch {
const std::string name;
float score;
};
namespace sherpa_onnx {
class SpeakerEmbeddingManager {
@@ -62,6 +67,25 @@ class SpeakerEmbeddingManager {
*/
std::string Search(const float *p, float threshold) const;
/**
* It is for speaker identification.
*
* It computes the cosine similarity between a given embedding and all
* other embeddings and finds the embeddings that have the largest scores
* and the scores are above or equal to the threshold. Returns a vector of
* SpeakerMatch structures containing the speaker names and scores for the
* embeddings if found; otherwise, returns an empty vector.
*
* @param p A pointer to the input embedding.
* @param threshold A value between 0 and 1.
* @param n The number of top matches to return.
* @return A vector of SpeakerMatch structures. If matches are found, the
* vector contains the names and scores of the speakers. Otherwise,
* it returns an empty vector.
*/
std::vector<SpeakerMatch> GetBestMatches(const float *p, float threshold,
int32_t n) const;
/* Check whether the input embedding matches the embedding of the input
* speaker.
*