WebAssembly example for speaker diarization (#1411)

This commit is contained in:
Fangjun Kuang
2024-10-10 22:14:45 +08:00
committed by GitHub
parent 67349b52f2
commit 1d061df355
37 changed files with 1008 additions and 24 deletions

View File

@@ -20,6 +20,10 @@ class OfflineSpeakerDiarizationImpl {
// Sample rate reported by the concrete implementation. The public
// OfflineSpeakerDiarization wrapper returns this value as the expected
// sample rate of the input audio samples.
virtual int32_t SampleRate() const = 0;
// Update the configuration of an already constructed implementation.
//
// Note: Only config.clustering is used. All other fields in config are
// ignored
virtual void SetConfig(const OfflineSpeakerDiarizationConfig &config) = 0;
virtual OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback = nullptr,

View File

@@ -60,7 +60,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
: config_(config),
segmentation_model_(config_.segmentation),
embedding_extractor_(config_.embedding),
clustering_(config_.clustering) {
clustering_(std::make_unique<FastClustering>(config_.clustering)) {
Init();
}
@@ -70,6 +70,15 @@ class OfflineSpeakerDiarizationPyannoteImpl
return meta_data.sample_rate;
}
// Swap in a new clustering configuration after construction.
//
// Only config.clustering is read. When it fails validation, an error is
// logged and both the current clustering object and the stored config are
// left untouched.
void SetConfig(const OfflineSpeakerDiarizationConfig &config) override {
  const auto &new_clustering = config.clustering;

  if (new_clustering.Validate()) {
    // Rebuild the clustering object first, then remember the config it was
    // built from.
    clustering_ = std::make_unique<FastClustering>(new_clustering);
    config_.clustering = new_clustering;
    return;
  }

  SHERPA_ONNX_LOGE("Invalid clustering config. Skip it");
}
OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
@@ -105,7 +114,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
std::move(callback), callback_arg);
std::vector<int32_t> cluster_labels = clustering_.Cluster(
std::vector<int32_t> cluster_labels = clustering_->Cluster(
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
int32_t max_cluster_index =
@@ -636,7 +645,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
OfflineSpeakerDiarizationConfig config_;
OfflineSpeakerSegmentationPyannoteModel segmentation_model_;
SpeakerEmbeddingExtractor embedding_extractor_;
FastClustering clustering_;
std::unique_ptr<FastClustering> clustering_;
Matrix2DInt32 powerset_mapping_;
};

View File

@@ -79,6 +79,11 @@ int32_t OfflineSpeakerDiarization::SampleRate() const {
return impl_->SampleRate();
}
// Forwards the given config to the underlying implementation object.
// Note: Only config.clustering is used. All other fields in config are
// ignored
void OfflineSpeakerDiarization::SetConfig(
const OfflineSpeakerDiarizationConfig &config) {
impl_->SetConfig(config);
}
OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/,

View File

@@ -62,6 +62,10 @@ class OfflineSpeakerDiarization {
// Expected sample rate of the input audio samples.
// The value is obtained from the underlying implementation.
int32_t SampleRate() const;
// Update the configuration of an existing diarization object.
//
// Note: Only config.clustering is used. All other fields in config are
// ignored
void SetConfig(const OfflineSpeakerDiarizationConfig &config);
OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback = nullptr,