WebAssembly example for speaker diarization (#1411)
This commit is contained in:
@@ -1749,6 +1749,20 @@ int32_t SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(
|
||||
return sd->impl->SampleRate();
|
||||
}
|
||||
|
||||
// C API entry point that updates the clustering settings of an existing
// offline speaker diarization object.
//
// @param sd      Handle whose underlying implementation receives the new
//                settings. It is dereferenced without a null check.
// @param config  Only config->clustering.num_clusters and
//                config->clustering.threshold are read; no other field of
//                config is accessed here. It is dereferenced without a null
//                check.
void SherpaOnnxOfflineSpeakerDiarizationSetConfig(
|
||||
const SherpaOnnxOfflineSpeakerDiarization *sd,
|
||||
const SherpaOnnxOfflineSpeakerDiarizationConfig *config) {
|
||||
// Fresh C++ config object; only its clustering member is assigned below.
sherpa_onnx::OfflineSpeakerDiarizationConfig sd_config;
|
||||
|
||||
// -1 is the fallback handed to SHERPA_ONNX_OR.
// NOTE(review): the macro is defined outside this view; presumably the
// fallback is used when the caller left the field zero - confirm.
sd_config.clustering.num_clusters =
|
||||
SHERPA_ONNX_OR(config->clustering.num_clusters, -1);
|
||||
|
||||
// 0.5 is the fallback handed to SHERPA_ONNX_OR for the threshold.
sd_config.clustering.threshold =
|
||||
SHERPA_ONNX_OR(config->clustering.threshold, 0.5);
|
||||
|
||||
// Forward the assembled config to the wrapped C++ object.
sd->impl->SetConfig(sd_config);
|
||||
}
|
||||
|
||||
int32_t SherpaOnnxOfflineSpeakerDiarizationResultGetNumSpeakers(
|
||||
const SherpaOnnxOfflineSpeakerDiarizationResult *r) {
|
||||
return r->impl.NumSpeakers();
|
||||
|
||||
@@ -1449,6 +1449,11 @@ SHERPA_ONNX_API void SherpaOnnxDestroyOfflineSpeakerDiarization(
|
||||
SHERPA_ONNX_API int32_t SherpaOnnxOfflineSpeakerDiarizationGetSampleRate(
|
||||
const SherpaOnnxOfflineSpeakerDiarization *sd);
|
||||
|
||||
// Only config->clustering is used. All other fields are ignored
|
||||
SHERPA_ONNX_API void SherpaOnnxOfflineSpeakerDiarizationSetConfig(
|
||||
const SherpaOnnxOfflineSpeakerDiarization *sd,
|
||||
const SherpaOnnxOfflineSpeakerDiarizationConfig *config);
|
||||
|
||||
SHERPA_ONNX_API typedef struct SherpaOnnxOfflineSpeakerDiarizationResult
|
||||
SherpaOnnxOfflineSpeakerDiarizationResult;
|
||||
|
||||
|
||||
@@ -20,6 +20,10 @@ class OfflineSpeakerDiarizationImpl {
|
||||
|
||||
virtual int32_t SampleRate() const = 0;
|
||||
|
||||
// Note: Only config.clustering is used. All other fields in config are
|
||||
// ignored
|
||||
virtual void SetConfig(const OfflineSpeakerDiarizationConfig &config) = 0;
|
||||
|
||||
virtual OfflineSpeakerDiarizationResult Process(
|
||||
const float *audio, int32_t n,
|
||||
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
|
||||
|
||||
@@ -60,7 +60,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
: config_(config),
|
||||
segmentation_model_(config_.segmentation),
|
||||
embedding_extractor_(config_.embedding),
|
||||
clustering_(config_.clustering) {
|
||||
clustering_(std::make_unique<FastClustering>(config_.clustering)) {
|
||||
Init();
|
||||
}
|
||||
|
||||
@@ -70,6 +70,15 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
return meta_data.sample_rate;
|
||||
}
|
||||
|
||||
// Replaces the clustering configuration used by this diarizer.
//
// Only config.clustering is read. If config.clustering.Validate() returns
// false, an error is logged and the current clustering object and stored
// config are left unchanged.
void SetConfig(const OfflineSpeakerDiarizationConfig &config) override {
|
||||
// Reject an invalid clustering config before touching any member state.
if (!config.clustering.Validate()) {
|
||||
SHERPA_ONNX_LOGE("Invalid clustering config. Skip it");
|
||||
return;
|
||||
}
|
||||
// Build a new FastClustering from the new settings; assigning to the
// unique_ptr releases the previously held instance.
clustering_ = std::make_unique<FastClustering>(config.clustering);
|
||||
// Keep the stored config in sync with the clustering object just created.
config_.clustering = config.clustering;
|
||||
}
|
||||
|
||||
OfflineSpeakerDiarizationResult Process(
|
||||
const float *audio, int32_t n,
|
||||
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
|
||||
@@ -105,7 +114,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
|
||||
std::move(callback), callback_arg);
|
||||
|
||||
std::vector<int32_t> cluster_labels = clustering_.Cluster(
|
||||
std::vector<int32_t> cluster_labels = clustering_->Cluster(
|
||||
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
|
||||
|
||||
int32_t max_cluster_index =
|
||||
@@ -636,7 +645,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
OfflineSpeakerDiarizationConfig config_;
|
||||
OfflineSpeakerSegmentationPyannoteModel segmentation_model_;
|
||||
SpeakerEmbeddingExtractor embedding_extractor_;
|
||||
FastClustering clustering_;
|
||||
std::unique_ptr<FastClustering> clustering_;
|
||||
Matrix2DInt32 powerset_mapping_;
|
||||
};
|
||||
|
||||
|
||||
@@ -79,6 +79,11 @@ int32_t OfflineSpeakerDiarization::SampleRate() const {
|
||||
return impl_->SampleRate();
|
||||
}
|
||||
|
||||
// Forwards the given config unchanged to the underlying implementation's
// SetConfig(). This wrapper performs no validation of its own.
void OfflineSpeakerDiarization::SetConfig(
|
||||
const OfflineSpeakerDiarizationConfig &config) {
|
||||
impl_->SetConfig(config);
|
||||
}
|
||||
|
||||
OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process(
|
||||
const float *audio, int32_t n,
|
||||
OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/,
|
||||
|
||||
@@ -62,6 +62,10 @@ class OfflineSpeakerDiarization {
|
||||
// Expected sample rate of the input audio samples
|
||||
int32_t SampleRate() const;
|
||||
|
||||
// Note: Only config.clustering is used. All other fields in config are
|
||||
// ignored
|
||||
void SetConfig(const OfflineSpeakerDiarizationConfig &config);
|
||||
|
||||
OfflineSpeakerDiarizationResult Process(
|
||||
const float *audio, int32_t n,
|
||||
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
|
||||
|
||||
@@ -68,6 +68,7 @@ void PybindOfflineSpeakerDiarization(py::module *m) {
|
||||
.def(py::init<const OfflineSpeakerDiarizationConfig &>(),
|
||||
py::arg("config"))
|
||||
.def_property_readonly("sample_rate", &PyClass::SampleRate)
|
||||
.def("set_config", &PyClass::SetConfig, py::arg("config"))
|
||||
.def(
|
||||
"process",
|
||||
[](const PyClass &self, const std::vector<float> samples,
|
||||
|
||||
Reference in New Issue
Block a user