Handle NaN embeddings in speaker diarization. (#1461)
See also https://github.com/thewh1teagle/sherpa-rs/issues/33
This commit is contained in:
@@ -19,7 +19,7 @@
|
|||||||
#include "sherpa-onnx/c-api/cxx-api.h"
|
#include "sherpa-onnx/c-api/cxx-api.h"
|
||||||
|
|
||||||
int32_t main() {
|
int32_t main() {
|
||||||
using namespace sherpa_onnx::cxx;
|
using namespace sherpa_onnx::cxx; // NOLINT
|
||||||
OfflineRecognizerConfig config;
|
OfflineRecognizerConfig config;
|
||||||
|
|
||||||
config.model_config.sense_voice.model =
|
config.model_config.sense_voice.model =
|
||||||
|
|||||||
@@ -20,7 +20,7 @@
|
|||||||
#include "sherpa-onnx/c-api/cxx-api.h"
|
#include "sherpa-onnx/c-api/cxx-api.h"
|
||||||
|
|
||||||
int32_t main() {
|
int32_t main() {
|
||||||
using namespace sherpa_onnx::cxx;
|
using namespace sherpa_onnx::cxx; // NOLINT
|
||||||
OnlineRecognizerConfig config;
|
OnlineRecognizerConfig config;
|
||||||
|
|
||||||
// please see
|
// please see
|
||||||
|
|||||||
@@ -19,7 +19,7 @@
|
|||||||
#include "sherpa-onnx/c-api/cxx-api.h"
|
#include "sherpa-onnx/c-api/cxx-api.h"
|
||||||
|
|
||||||
int32_t main() {
|
int32_t main() {
|
||||||
using namespace sherpa_onnx::cxx;
|
using namespace sherpa_onnx::cxx; // NOLINT
|
||||||
OfflineRecognizerConfig config;
|
OfflineRecognizerConfig config;
|
||||||
|
|
||||||
config.model_config.whisper.encoder =
|
config.model_config.whisper.encoder =
|
||||||
|
|||||||
@@ -71,6 +71,9 @@ function is_source_code_file() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function check_style() {
|
function check_style() {
|
||||||
|
if [[ $1 == mfc-example* ]]; then
|
||||||
|
return
|
||||||
|
fi
|
||||||
python3 $cpplint_src $1 || abort $1
|
python3 $cpplint_src $1 || abort $1
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,7 +102,7 @@ function do_check() {
|
|||||||
;;
|
;;
|
||||||
2)
|
2)
|
||||||
echo "Check all files"
|
echo "Check all files"
|
||||||
files=$(find $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc")
|
files=$(find $sherpa_onnx_dir/cxx-api-examples $sherpa_onnx_dir/c-api-examples $sherpa_onnx_dir/sherpa-onnx/csrc $sherpa_onnx_dir/sherpa-onnx/python $sherpa_onnx_dir/scripts/node-addon-api/src $sherpa_onnx_dir/sherpa-onnx/jni $sherpa_onnx_dir/sherpa-onnx/c-api -name "*.h" -o -name "*.cc")
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
echo "Check last commit"
|
echo "Check last commit"
|
||||||
|
|||||||
@@ -5,6 +5,7 @@
|
|||||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
|
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cmath>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
@@ -135,9 +136,32 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
|||||||
}
|
}
|
||||||
|
|
||||||
auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
|
auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
|
||||||
|
|
||||||
|
// The embedding model may output NaN. valid_indexes contains indexes
|
||||||
|
// in chunk_speaker_samples_list_pair.second that don't lead to
|
||||||
|
// NaN embeddings.
|
||||||
|
std::vector<int32_t> valid_indexes;
|
||||||
|
valid_indexes.reserve(chunk_speaker_samples_list_pair.second.size());
|
||||||
|
|
||||||
Matrix2D embeddings =
|
Matrix2D embeddings =
|
||||||
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
|
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
|
||||||
std::move(callback), callback_arg);
|
&valid_indexes, std::move(callback), callback_arg);
|
||||||
|
|
||||||
|
if (valid_indexes.size() != chunk_speaker_samples_list_pair.second.size()) {
|
||||||
|
std::vector<Int32Pair> chunk_speaker_pair;
|
||||||
|
std::vector<std::vector<Int32Pair>> sample_indexes;
|
||||||
|
|
||||||
|
chunk_speaker_pair.reserve(valid_indexes.size());
|
||||||
|
sample_indexes.reserve(valid_indexes.size());
|
||||||
|
for (auto i : valid_indexes) {
|
||||||
|
chunk_speaker_pair.push_back(chunk_speaker_samples_list_pair.first[i]);
|
||||||
|
sample_indexes.push_back(
|
||||||
|
std::move(chunk_speaker_samples_list_pair.second[i]));
|
||||||
|
}
|
||||||
|
|
||||||
|
chunk_speaker_samples_list_pair.first = std::move(chunk_speaker_pair);
|
||||||
|
chunk_speaker_samples_list_pair.second = std::move(sample_indexes);
|
||||||
|
}
|
||||||
|
|
||||||
std::vector<int32_t> cluster_labels = clustering_->Cluster(
|
std::vector<int32_t> cluster_labels = clustering_->Cluster(
|
||||||
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
|
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
|
||||||
@@ -431,13 +455,17 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
|||||||
Matrix2D ComputeEmbeddings(
|
Matrix2D ComputeEmbeddings(
|
||||||
const float *audio, int32_t n,
|
const float *audio, int32_t n,
|
||||||
const std::vector<std::vector<Int32Pair>> &sample_indexes,
|
const std::vector<std::vector<Int32Pair>> &sample_indexes,
|
||||||
|
std::vector<int32_t> *valid_indexes,
|
||||||
OfflineSpeakerDiarizationProgressCallback callback,
|
OfflineSpeakerDiarizationProgressCallback callback,
|
||||||
void *callback_arg) const {
|
void *callback_arg) const {
|
||||||
const auto &meta_data = segmentation_model_.GetModelMetaData();
|
const auto &meta_data = segmentation_model_.GetModelMetaData();
|
||||||
int32_t sample_rate = meta_data.sample_rate;
|
int32_t sample_rate = meta_data.sample_rate;
|
||||||
Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim());
|
Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim());
|
||||||
|
|
||||||
|
auto IsNaNWrapper = [](float f) -> bool { return std::isnan(f); };
|
||||||
|
|
||||||
int32_t k = 0;
|
int32_t k = 0;
|
||||||
|
int32_t cur_row_index = 0;
|
||||||
for (const auto &v : sample_indexes) {
|
for (const auto &v : sample_indexes) {
|
||||||
auto stream = embedding_extractor_.CreateStream();
|
auto stream = embedding_extractor_.CreateStream();
|
||||||
for (const auto &p : v) {
|
for (const auto &p : v) {
|
||||||
@@ -459,7 +487,12 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
|||||||
|
|
||||||
std::vector<float> embedding = embedding_extractor_.Compute(stream.get());
|
std::vector<float> embedding = embedding_extractor_.Compute(stream.get());
|
||||||
|
|
||||||
std::copy(embedding.begin(), embedding.end(), &ans(k, 0));
|
if (std::none_of(embedding.begin(), embedding.end(), IsNaNWrapper)) {
|
||||||
|
// a valid embedding
|
||||||
|
std::copy(embedding.begin(), embedding.end(), &ans(cur_row_index, 0));
|
||||||
|
cur_row_index += 1;
|
||||||
|
valid_indexes->push_back(k);
|
||||||
|
}
|
||||||
|
|
||||||
k += 1;
|
k += 1;
|
||||||
|
|
||||||
@@ -468,6 +501,11 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (k != cur_row_index) {
|
||||||
|
auto seq = Eigen::seqN(0, cur_row_index);
|
||||||
|
ans = ans(seq, Eigen::all);
|
||||||
|
}
|
||||||
|
|
||||||
return ans;
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -122,7 +122,7 @@ class SpeakerEmbeddingExtractorNeMoImpl : public SpeakerEmbeddingExtractorImpl {
|
|||||||
auto variance = EX2 - EX.array().pow(2);
|
auto variance = EX2 - EX.array().pow(2);
|
||||||
auto stddev = variance.array().sqrt();
|
auto stddev = variance.array().sqrt();
|
||||||
|
|
||||||
m = (m.rowwise() - EX).array().rowwise() / stddev.array();
|
m = (m.rowwise() - EX).array().rowwise() / (stddev.array() + 1e-5);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|||||||
Reference in New Issue
Block a user