Handle NaN embeddings in speaker diarization. (#1461)

See also https://github.com/thewh1teagle/sherpa-rs/issues/33
This commit is contained in:
Fangjun Kuang
2024-10-24 14:03:09 +08:00
committed by GitHub
parent b3e05f6dc4
commit a5295aad10
6 changed files with 48 additions and 7 deletions

View File

@@ -5,6 +5,7 @@
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
#include <algorithm>
#include <cmath>
#include <memory>
#include <unordered_map>
#include <utility>
@@ -135,9 +136,32 @@ class OfflineSpeakerDiarizationPyannoteImpl
}
auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
// The embedding model may output NaN. valid_indexes contains indexes
// in chunk_speaker_samples_list_pair.second that don't lead to
// NaN embeddings.
std::vector<int32_t> valid_indexes;
valid_indexes.reserve(chunk_speaker_samples_list_pair.second.size());
Matrix2D embeddings =
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
std::move(callback), callback_arg);
&valid_indexes, std::move(callback), callback_arg);
if (valid_indexes.size() != chunk_speaker_samples_list_pair.second.size()) {
std::vector<Int32Pair> chunk_speaker_pair;
std::vector<std::vector<Int32Pair>> sample_indexes;
chunk_speaker_pair.reserve(valid_indexes.size());
sample_indexes.reserve(valid_indexes.size());
for (auto i : valid_indexes) {
chunk_speaker_pair.push_back(chunk_speaker_samples_list_pair.first[i]);
sample_indexes.push_back(
std::move(chunk_speaker_samples_list_pair.second[i]));
}
chunk_speaker_samples_list_pair.first = std::move(chunk_speaker_pair);
chunk_speaker_samples_list_pair.second = std::move(sample_indexes);
}
std::vector<int32_t> cluster_labels = clustering_->Cluster(
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
@@ -431,13 +455,17 @@ class OfflineSpeakerDiarizationPyannoteImpl
Matrix2D ComputeEmbeddings(
const float *audio, int32_t n,
const std::vector<std::vector<Int32Pair>> &sample_indexes,
std::vector<int32_t> *valid_indexes,
OfflineSpeakerDiarizationProgressCallback callback,
void *callback_arg) const {
const auto &meta_data = segmentation_model_.GetModelMetaData();
int32_t sample_rate = meta_data.sample_rate;
Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim());
auto IsNaNWrapper = [](float f) -> bool { return std::isnan(f); };
int32_t k = 0;
int32_t cur_row_index = 0;
for (const auto &v : sample_indexes) {
auto stream = embedding_extractor_.CreateStream();
for (const auto &p : v) {
@@ -459,7 +487,12 @@ class OfflineSpeakerDiarizationPyannoteImpl
std::vector<float> embedding = embedding_extractor_.Compute(stream.get());
std::copy(embedding.begin(), embedding.end(), &ans(k, 0));
if (std::none_of(embedding.begin(), embedding.end(), IsNaNWrapper)) {
// a valid embedding
std::copy(embedding.begin(), embedding.end(), &ans(cur_row_index, 0));
cur_row_index += 1;
valid_indexes->push_back(k);
}
k += 1;
@@ -468,6 +501,11 @@ class OfflineSpeakerDiarizationPyannoteImpl
}
}
if (k != cur_row_index) {
auto seq = Eigen::seqN(0, cur_row_index);
ans = ans(seq, Eigen::all);
}
return ans;
}

View File

@@ -122,7 +122,7 @@ class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl {
auto variance = EX2 - EX.array().pow(2);
auto stddev = variance.array().sqrt();
m = (m.rowwise() - EX).array().rowwise() / stddev.array();
m = (m.rowwise() - EX).array().rowwise() / (stddev.array() + 1e-5);
}
private: