WebAssembly example for speaker diarization (#1411)

This commit is contained in:
Fangjun Kuang
2024-10-10 22:14:45 +08:00
committed by GitHub
parent 67349b52f2
commit 1d061df355
37 changed files with 1008 additions and 24 deletions

View File

@@ -1749,6 +1749,20 @@ int32_t SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(
return sd->impl->SampleRate();
}
void SherpaOnnxOfflineSpeakerDiarizationSetConfig(
const SherpaOnnxOfflineSpeakerDiarization *sd,
const SherpaOnnxOfflineSpeakerDiarizationConfig *config) {
sherpa_onnx::OfflineSpeakerDiarizationConfig sd_config;
sd_config.clustering.num_clusters =
SHERPA_ONNX_OR(config->clustering.num_clusters, -1);
sd_config.clustering.threshold =
SHERPA_ONNX_OR(config->clustering.threshold, 0.5);
sd->impl->SetConfig(sd_config);
}
int32_t SherpaOnnxOfflineSpeakerDiarizationResultGetNumSpeakers(
const SherpaOnnxOfflineSpeakerDiarizationResult *r) {
return r->impl.NumSpeakers();

View File

@@ -1449,6 +1449,11 @@ SHERPA_ONNX_API void SherpaOnnxDestroyOfflineSpeakerDiarization(
SHERPA_ONNX_API int32_t SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(
const SherpaOnnxOfflineSpeakerDiarization *sd);
// Update the clustering settings of an existing speaker diarization object.
//
// Only config->clustering is used. All other fields are ignored.
//
// @param sd      The diarization object whose settings are updated.
// @param config  Source of the new clustering settings (num_clusters and
//                threshold).
// NOTE(review): both pointers are presumably required to be non-null --
// confirm against the implementation.
SHERPA_ONNX_API void SherpaOnnxOfflineSpeakerDiarizationSetConfig(
const SherpaOnnxOfflineSpeakerDiarization *sd,
const SherpaOnnxOfflineSpeakerDiarizationConfig *config);
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSpeakerDiarizationResult
SherpaOnnxOfflineSpeakerDiarizationResult;

View File

@@ -20,6 +20,10 @@ class OfflineSpeakerDiarizationImpl {
virtual int32_t SampleRate() const = 0;
// Apply a new configuration to this implementation.
//
// Note: Only config.clustering is used. All other fields in config are
// ignored
//
// @param config  Configuration carrying the new clustering settings.
virtual void SetConfig(const OfflineSpeakerDiarizationConfig &config) = 0;
virtual OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback = nullptr,

View File

@@ -60,7 +60,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
: config_(config),
segmentation_model_(config_.segmentation),
embedding_extractor_(config_.embedding),
clustering_(config_.clustering) {
clustering_(std::make_unique<FastClustering>(config_.clustering)) {
Init();
}
@@ -70,6 +70,15 @@ class OfflineSpeakerDiarizationPyannoteImpl
return meta_data.sample_rate;
}
// Swap in a new clustering configuration.
//
// Only config.clustering is consulted. When it passes Validate(), a new
// FastClustering is built from it and the stored config_.clustering is
// updated to match. Otherwise an error is logged and the current state is
// kept untouched.
void SetConfig(const OfflineSpeakerDiarizationConfig &config) override {
  const auto &requested = config.clustering;

  if (requested.Validate()) {
    // Build the clusterer first, then record the settings it was built from.
    clustering_ = std::make_unique<FastClustering>(requested);
    config_.clustering = requested;
    return;
  }

  SHERPA_ONNX_LOGE("Invalid clustering config. Skip it");
}
OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
@@ -105,7 +114,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
std::move(callback), callback_arg);
std::vector<int32_t> cluster_labels = clustering_.Cluster(
std::vector<int32_t> cluster_labels = clustering_->Cluster(
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
int32_t max_cluster_index =
@@ -636,7 +645,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
OfflineSpeakerDiarizationConfig config_;
OfflineSpeakerSegmentationPyannoteModel segmentation_model_;
SpeakerEmbeddingExtractor embedding_extractor_;
FastClustering clustering_;
std::unique_ptr<FastClustering> clustering_;
Matrix2DInt32 powerset_mapping_;
};

View File

@@ -79,6 +79,11 @@ int32_t OfflineSpeakerDiarization::SampleRate() const {
return impl_->SampleRate();
}
// Forwards the given configuration to the underlying implementation.
//
// @param config  Configuration passed through unchanged to impl_.
void OfflineSpeakerDiarization::SetConfig(
const OfflineSpeakerDiarizationConfig &config) {
// No processing happens here; impl_ decides which fields take effect.
impl_->SetConfig(config);
}
OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/,

View File

@@ -62,6 +62,10 @@ class OfflineSpeakerDiarization {
// Expected sample rate of the input audio samples
int32_t SampleRate() const;
// Update the configuration of this diarization object after construction.
//
// Note: Only config.clustering is used. All other fields in config are
// ignored
//
// @param config  Configuration carrying the new clustering settings.
void SetConfig(const OfflineSpeakerDiarizationConfig &config);
OfflineSpeakerDiarizationResult Process(
const float *audio, int32_t n,
OfflineSpeakerDiarizationProgressCallback callback = nullptr,

View File

@@ -68,6 +68,7 @@ void PybindOfflineSpeakerDiarization(py::module *m) {
.def(py::init<const OfflineSpeakerDiarizationConfig &>(),
py::arg("config"))
.def_property_readonly("sample_rate", &PyClass::SampleRate)
.def("set_config", &PyClass::SetConfig, py::arg("config"))
.def(
"process",
[](const PyClass &self, const std::vector<float> samples,