WebAssembly example for speaker diarization (#1411)
This commit is contained in:
@@ -20,6 +20,10 @@ class OfflineSpeakerDiarizationImpl {
|
||||
|
||||
virtual int32_t SampleRate() const = 0;
|
||||
|
||||
// Note: Only config.clustering is used. All other fields in config are
|
||||
// ignored
|
||||
virtual void SetConfig(const OfflineSpeakerDiarizationConfig &config) = 0;
|
||||
|
||||
virtual OfflineSpeakerDiarizationResult Process(
|
||||
const float *audio, int32_t n,
|
||||
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
|
||||
|
||||
@@ -60,7 +60,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
: config_(config),
|
||||
segmentation_model_(config_.segmentation),
|
||||
embedding_extractor_(config_.embedding),
|
||||
clustering_(config_.clustering) {
|
||||
clustering_(std::make_unique<FastClustering>(config_.clustering)) {
|
||||
Init();
|
||||
}
|
||||
|
||||
@@ -70,6 +70,15 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
return meta_data.sample_rate;
|
||||
}
|
||||
|
||||
// Replaces the clustering settings of this diarizer.
//
// Only config.clustering is read; every other field of config is ignored.
// If config.clustering.Validate() returns false, an error is logged and
// both clustering_ and config_ are left unchanged.
void SetConfig(const OfflineSpeakerDiarizationConfig &config) override {
|
||||
if (!config.clustering.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Invalid clustering config. Skip it");
|
||||
return;
|
||||
}
|
||||
// Build a fresh FastClustering from the new settings; later calls made
// through clustering_ use this new object.
clustering_ = std::make_unique<FastClustering>(config.clustering);
|
||||
// Keep the stored config in sync with the clusterer that is now active.
config_.clustering = config.clustering;
|
||||
}
|
||||
|
||||
OfflineSpeakerDiarizationResult Process(
|
||||
const float *audio, int32_t n,
|
||||
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
|
||||
@@ -105,7 +114,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
|
||||
std::move(callback), callback_arg);
|
||||
|
||||
std::vector<int32_t> cluster_labels = clustering_.Cluster(
|
||||
std::vector<int32_t> cluster_labels = clustering_->Cluster(
|
||||
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
|
||||
|
||||
int32_t max_cluster_index =
|
||||
@@ -636,7 +645,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
OfflineSpeakerDiarizationConfig config_;
|
||||
OfflineSpeakerSegmentationPyannoteModel segmentation_model_;
|
||||
SpeakerEmbeddingExtractor embedding_extractor_;
|
||||
FastClustering clustering_;
|
||||
std::unique_ptr<FastClustering> clustering_;
|
||||
Matrix2DInt32 powerset_mapping_;
|
||||
};
|
||||
|
||||
|
||||
@@ -79,6 +79,11 @@ int32_t OfflineSpeakerDiarization::SampleRate() const {
|
||||
return impl_->SampleRate();
|
||||
}
|
||||
|
||||
// Forwards config unchanged to impl_->SetConfig().
void OfflineSpeakerDiarization::SetConfig(
|
||||
const OfflineSpeakerDiarizationConfig &config) {
|
||||
impl_->SetConfig(config);
|
||||
}
|
||||
|
||||
OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process(
|
||||
const float *audio, int32_t n,
|
||||
OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/,
|
||||
|
||||
@@ -62,6 +62,10 @@ class OfflineSpeakerDiarization {
|
||||
// Expected sample rate of the input audio samples
|
||||
int32_t SampleRate() const;
|
||||
|
||||
// Note: Only config.clustering is used. All other fields in config are
|
||||
// ignored
|
||||
void SetConfig(const OfflineSpeakerDiarizationConfig &config);
|
||||
|
||||
OfflineSpeakerDiarizationResult Process(
|
||||
const float *audio, int32_t n,
|
||||
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
|
||||
|
||||
Reference in New Issue
Block a user