Handle NaN embeddings in speaker diarization. (#1461)
See also https://github.com/thewh1teagle/sherpa-rs/issues/33
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
@@ -135,9 +136,32 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
}
|
||||
|
||||
auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
|
||||
|
||||
// The embedding model may output NaN. valid_indexes contains indexes
|
||||
// in chunk_speaker_samples_list_pair.second that don't lead to
|
||||
// NaN embeddings.
|
||||
std::vector<int32_t> valid_indexes;
|
||||
valid_indexes.reserve(chunk_speaker_samples_list_pair.second.size());
|
||||
|
||||
Matrix2D embeddings =
|
||||
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
|
||||
std::move(callback), callback_arg);
|
||||
&valid_indexes, std::move(callback), callback_arg);
|
||||
|
||||
if (valid_indexes.size() != chunk_speaker_samples_list_pair.second.size()) {
|
||||
std::vector<Int32Pair> chunk_speaker_pair;
|
||||
std::vector<std::vector<Int32Pair>> sample_indexes;
|
||||
|
||||
chunk_speaker_pair.reserve(valid_indexes.size());
|
||||
sample_indexes.reserve(valid_indexes.size());
|
||||
for (auto i : valid_indexes) {
|
||||
chunk_speaker_pair.push_back(chunk_speaker_samples_list_pair.first[i]);
|
||||
sample_indexes.push_back(
|
||||
std::move(chunk_speaker_samples_list_pair.second[i]));
|
||||
}
|
||||
|
||||
chunk_speaker_samples_list_pair.first = std::move(chunk_speaker_pair);
|
||||
chunk_speaker_samples_list_pair.second = std::move(sample_indexes);
|
||||
}
|
||||
|
||||
std::vector<int32_t> cluster_labels = clustering_->Cluster(
|
||||
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
|
||||
@@ -431,13 +455,17 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
Matrix2D ComputeEmbeddings(
|
||||
const float *audio, int32_t n,
|
||||
const std::vector<std::vector<Int32Pair>> &sample_indexes,
|
||||
std::vector<int32_t> *valid_indexes,
|
||||
OfflineSpeakerDiarizationProgressCallback callback,
|
||||
void *callback_arg) const {
|
||||
const auto &meta_data = segmentation_model_.GetModelMetaData();
|
||||
int32_t sample_rate = meta_data.sample_rate;
|
||||
Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim());
|
||||
|
||||
auto IsNaNWrapper = [](float f) -> bool { return std::isnan(f); };
|
||||
|
||||
int32_t k = 0;
|
||||
int32_t cur_row_index = 0;
|
||||
for (const auto &v : sample_indexes) {
|
||||
auto stream = embedding_extractor_.CreateStream();
|
||||
for (const auto &p : v) {
|
||||
@@ -459,7 +487,12 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
|
||||
std::vector<float> embedding = embedding_extractor_.Compute(stream.get());
|
||||
|
||||
std::copy(embedding.begin(), embedding.end(), &ans(k, 0));
|
||||
if (std::none_of(embedding.begin(), embedding.end(), IsNaNWrapper)) {
|
||||
// a valid embedding
|
||||
std::copy(embedding.begin(), embedding.end(), &ans(cur_row_index, 0));
|
||||
cur_row_index += 1;
|
||||
valid_indexes->push_back(k);
|
||||
}
|
||||
|
||||
k += 1;
|
||||
|
||||
@@ -468,6 +501,11 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
}
|
||||
}
|
||||
|
||||
if (k != cur_row_index) {
|
||||
auto seq = Eigen::seqN(0, cur_row_index);
|
||||
ans = ans(seq, Eigen::all);
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user