feat: find best embedding matches (#1102)
This commit is contained in:
@@ -1256,6 +1256,44 @@ void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(const char *name) {
|
|||||||
delete[] name;
|
delete[] name;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *
|
||||||
|
SherpaOnnxSpeakerEmbeddingManagerGetBestMatches(
|
||||||
|
const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold,
|
||||||
|
int32_t n) {
|
||||||
|
auto matches = p->impl->GetBestMatches(v, threshold, n);
|
||||||
|
|
||||||
|
if (matches.empty()) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto resultMatches =
|
||||||
|
new SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch[matches.size()];
|
||||||
|
for (int i = 0; i < matches.size(); ++i) {
|
||||||
|
resultMatches[i].score = matches[i].score;
|
||||||
|
|
||||||
|
char *name = new char[matches[i].name.size() + 1];
|
||||||
|
std::copy(matches[i].name.begin(), matches[i].name.end(), name);
|
||||||
|
name[matches[i].name.size()] = '\0';
|
||||||
|
|
||||||
|
resultMatches[i].name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto *result = new SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult();
|
||||||
|
result->count = matches.size();
|
||||||
|
result->matches = resultMatches;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
|
||||||
|
const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r) {
|
||||||
|
for (int32_t i = 0; i < r->count; ++i) {
|
||||||
|
delete[] r->matches[i].name;
|
||||||
|
}
|
||||||
|
delete[] r->matches;
|
||||||
|
delete r;
|
||||||
|
};
|
||||||
|
|
||||||
int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
|
int32_t SherpaOnnxSpeakerEmbeddingManagerVerify(
|
||||||
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
|
const SherpaOnnxSpeakerEmbeddingManager *p, const char *name,
|
||||||
const float *v, float threshold) {
|
const float *v, float threshold) {
|
||||||
|
|||||||
@@ -1109,6 +1109,39 @@ SHERPA_ONNX_API const char *SherpaOnnxSpeakerEmbeddingManagerSearch(
|
|||||||
SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(
|
SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeSearch(
|
||||||
const char *name);
|
const char *name);
|
||||||
|
|
||||||
|
SHERPA_ONNX_API typedef struct SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch {
|
||||||
|
float score;
|
||||||
|
const char *name;
|
||||||
|
} SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch;
|
||||||
|
|
||||||
|
SHERPA_ONNX_API typedef struct
|
||||||
|
SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult {
|
||||||
|
const SherpaOnnxSpeakerEmbeddingManagerSpeakerMatch *matches;
|
||||||
|
int32_t count;
|
||||||
|
} SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult;
|
||||||
|
|
||||||
|
// Get the best matching speakers whose embeddings match the given
|
||||||
|
// embedding.
|
||||||
|
//
|
||||||
|
// @param p Pointer to the SherpaOnnxSpeakerEmbeddingManager instance.
|
||||||
|
// @param v Pointer to an array containing the embedding vector.
|
||||||
|
// @param threshold Minimum similarity score required for a match (between 0 and
|
||||||
|
// 1).
|
||||||
|
// @param n Number of best matches to retrieve.
|
||||||
|
// @return Returns a pointer to
|
||||||
|
// SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult
|
||||||
|
// containing the best matches found. Returns NULL if no matches are
|
||||||
|
// found. The caller is responsible for freeing the returned pointer
|
||||||
|
// using SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches() to
|
||||||
|
// avoid memory leaks.
|
||||||
|
SHERPA_ONNX_API const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *
|
||||||
|
SherpaOnnxSpeakerEmbeddingManagerGetBestMatches(
|
||||||
|
const SherpaOnnxSpeakerEmbeddingManager *p, const float *v, float threshold,
|
||||||
|
int32_t n);
|
||||||
|
|
||||||
|
SHERPA_ONNX_API void SherpaOnnxSpeakerEmbeddingManagerFreeBestMatches(
|
||||||
|
const SherpaOnnxSpeakerEmbeddingManagerBestMatchesResult *r);
|
||||||
|
|
||||||
// Check whether the input embedding matches the embedding of the input
|
// Check whether the input embedding matches the embedding of the input
|
||||||
// speaker.
|
// speaker.
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -131,6 +131,40 @@ class SpeakerEmbeddingManager::Impl {
|
|||||||
return row2name_.at(max_index);
|
return row2name_.at(max_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<SpeakerMatch> GetBestMatches(const float *p, float threshold,
|
||||||
|
int32_t n) {
|
||||||
|
std::vector<SpeakerMatch> matches;
|
||||||
|
|
||||||
|
if (embedding_matrix_.rows() == 0) {
|
||||||
|
return matches;
|
||||||
|
}
|
||||||
|
|
||||||
|
Eigen::VectorXf v =
|
||||||
|
Eigen::Map<Eigen::VectorXf>(const_cast<float *>(p), dim_);
|
||||||
|
v.normalize();
|
||||||
|
|
||||||
|
Eigen::VectorXf scores = embedding_matrix_ * v;
|
||||||
|
|
||||||
|
std::vector<std::pair<float, int>> score_indices;
|
||||||
|
for (int i = 0; i < scores.size(); ++i) {
|
||||||
|
if (scores[i] >= threshold) {
|
||||||
|
score_indices.emplace_back(scores[i], i);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(score_indices.rbegin(), score_indices.rend(),
|
||||||
|
[](const auto &a, const auto &b) { return a.first < b.first; });
|
||||||
|
|
||||||
|
matches.reserve(score_indices.size());
|
||||||
|
for (int i = 0; i < std::min(n, static_cast<int32_t>(score_indices.size()));
|
||||||
|
++i) {
|
||||||
|
const auto &pair = score_indices[i];
|
||||||
|
matches.push_back({row2name_.at(pair.second), pair.first});
|
||||||
|
}
|
||||||
|
|
||||||
|
return matches;
|
||||||
|
}
|
||||||
|
|
||||||
bool Verify(const std::string &name, const float *p, float threshold) {
|
bool Verify(const std::string &name, const float *p, float threshold) {
|
||||||
if (!name2row_.count(name)) {
|
if (!name2row_.count(name)) {
|
||||||
return false;
|
return false;
|
||||||
@@ -219,6 +253,11 @@ std::string SpeakerEmbeddingManager::Search(const float *p,
|
|||||||
return impl_->Search(p, threshold);
|
return impl_->Search(p, threshold);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<SpeakerMatch> SpeakerEmbeddingManager::GetBestMatches(
|
||||||
|
const float *p, float threshold, int32_t n) const {
|
||||||
|
return impl_->GetBestMatches(p, threshold, n);
|
||||||
|
}
|
||||||
|
|
||||||
bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
|
bool SpeakerEmbeddingManager::Verify(const std::string &name, const float *p,
|
||||||
float threshold) const {
|
float threshold) const {
|
||||||
return impl_->Verify(name, p, threshold);
|
return impl_->Verify(name, p, threshold);
|
||||||
|
|||||||
@@ -9,6 +9,11 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
struct SpeakerMatch {
|
||||||
|
const std::string name;
|
||||||
|
float score;
|
||||||
|
};
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
class SpeakerEmbeddingManager {
|
class SpeakerEmbeddingManager {
|
||||||
@@ -62,6 +67,25 @@ class SpeakerEmbeddingManager {
|
|||||||
*/
|
*/
|
||||||
std::string Search(const float *p, float threshold) const;
|
std::string Search(const float *p, float threshold) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* It is for speaker identification.
|
||||||
|
*
|
||||||
|
* It computes the cosine similarity between a given embedding and all
|
||||||
|
* other embeddings and finds the embeddings that have the largest scores
|
||||||
|
* and the scores are above or equal to the threshold. Returns a vector of
|
||||||
|
* SpeakerMatch structures containing the speaker names and scores for the
|
||||||
|
* embeddings if found; otherwise, returns an empty vector.
|
||||||
|
*
|
||||||
|
* @param p A pointer to the input embedding.
|
||||||
|
* @param threshold A value between 0 and 1.
|
||||||
|
* @param n The number of top matches to return.
|
||||||
|
* @return A vector of SpeakerMatch structures. If matches are found, the
|
||||||
|
* vector contains the names and scores of the speakers. Otherwise,
|
||||||
|
* it returns an empty vector.
|
||||||
|
*/
|
||||||
|
std::vector<SpeakerMatch> GetBestMatches(const float *p, float threshold,
|
||||||
|
int32_t n) const;
|
||||||
|
|
||||||
/* Check whether the input embedding matches the embedding of the input
|
/* Check whether the input embedding matches the embedding of the input
|
||||||
* speaker.
|
* speaker.
|
||||||
*
|
*
|
||||||
|
|||||||
Reference in New Issue
Block a user