Python API for speaker diarization. (#1400)

This commit is contained in:
Fangjun Kuang
2024-10-09 14:13:26 +08:00
committed by GitHub
parent 59407edcad
commit 8535b1d3bb
14 changed files with 315 additions and 9 deletions

View File

@@ -62,6 +62,8 @@ endif()
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
list(APPEND srcs
fast-clustering.cc
offline-speaker-diarization-result.cc
offline-speaker-diarization.cc
)
endif()

View File

@@ -0,0 +1,32 @@
// sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
namespace sherpa_onnx {
static void PybindOfflineSpeakerDiarizationSegment(py::module *m) {
using PyClass = OfflineSpeakerDiarizationSegment;
py::class_<PyClass>(*m, "OfflineSpeakerDiarizationSegment")
.def_property_readonly("start", &PyClass::Start)
.def_property_readonly("end", &PyClass::End)
.def_property_readonly("duration", &PyClass::Duration)
.def_property_readonly("speaker", &PyClass::Speaker)
.def_property("text", &PyClass::Text, &PyClass::SetText)
.def("__str__", &PyClass::ToString);
}
void PybindOfflineSpeakerDiarizationResult(py::module *m) {
PybindOfflineSpeakerDiarizationSegment(m);
using PyClass = OfflineSpeakerDiarizationResult;
py::class_<PyClass>(*m, "OfflineSpeakerDiarizationResult")
.def_property_readonly("num_speakers", &PyClass::NumSpeakers)
.def_property_readonly("num_segments", &PyClass::NumSegments)
.def("sort_by_start_time", &PyClass::SortByStartTime)
.def("sort_by_speaker", &PyClass::SortBySpeaker);
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineSpeakerDiarizationResult(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_

View File

@@ -0,0 +1,92 @@
// sherpa-onnx/python/csrc/offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"
#include <string>
#include <vector>
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
namespace sherpa_onnx {
static void PybindOfflineSpeakerSegmentationPyannoteModelConfig(py::module *m) {
using PyClass = OfflineSpeakerSegmentationPyannoteModelConfig;
py::class_<PyClass>(*m, "OfflineSpeakerSegmentationPyannoteModelConfig")
.def(py::init<>())
.def(py::init<const std::string &>(), py::arg("model"))
.def_readwrite("model", &PyClass::model)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}
static void PybindOfflineSpeakerSegmentationModelConfig(py::module *m) {
PybindOfflineSpeakerSegmentationPyannoteModelConfig(m);
using PyClass = OfflineSpeakerSegmentationModelConfig;
py::class_<PyClass>(*m, "OfflineSpeakerSegmentationModelConfig")
.def(py::init<>())
.def(py::init<const OfflineSpeakerSegmentationPyannoteModelConfig &,
int32_t, bool, const std::string &>(),
py::arg("pyannote"), py::arg("num_threads") = 1,
py::arg("debug") = false, py::arg("provider") = "cpu")
.def_readwrite("pyannote", &PyClass::pyannote)
.def_readwrite("num_threads", &PyClass::num_threads)
.def_readwrite("debug", &PyClass::debug)
.def_readwrite("provider", &PyClass::provider)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}
static void PybindOfflineSpeakerDiarizationConfig(py::module *m) {
PybindOfflineSpeakerSegmentationModelConfig(m);
using PyClass = OfflineSpeakerDiarizationConfig;
py::class_<PyClass>(*m, "OfflineSpeakerDiarizationConfig")
.def(py::init<const OfflineSpeakerSegmentationModelConfig &,
const SpeakerEmbeddingExtractorConfig &,
const FastClusteringConfig &, float, float>(),
py::arg("segmentation"), py::arg("embedding"), py::arg("clustering"),
py::arg("min_duration_on") = 0.3, py::arg("min_duration_off") = 0.5)
.def_readwrite("segmentation", &PyClass::segmentation)
.def_readwrite("embedding", &PyClass::embedding)
.def_readwrite("clustering", &PyClass::clustering)
.def_readwrite("min_duration_on", &PyClass::min_duration_on)
.def_readwrite("min_duration_off", &PyClass::min_duration_off)
.def("__str__", &PyClass::ToString)
.def("validate", &PyClass::Validate);
}
void PybindOfflineSpeakerDiarization(py::module *m) {
PybindOfflineSpeakerDiarizationConfig(m);
using PyClass = OfflineSpeakerDiarization;
py::class_<PyClass>(*m, "OfflineSpeakerDiarization")
.def(py::init<const OfflineSpeakerDiarizationConfig &>(),
py::arg("config"))
.def_property_readonly("sample_rate", &PyClass::SampleRate)
.def(
"process",
[](const PyClass &self, const std::vector<float> samples,
std::function<int32_t(int32_t, int32_t)> callback) {
if (!callback) {
return self.Process(samples.data(), samples.size());
}
std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper =
[callback](int32_t processed_chunks, int32_t num_chunks,
void *) -> int32_t {
callback(processed_chunks, num_chunks);
return 0;
};
return self.Process(samples.data(), samples.size(),
callback_wrapper);
},
py::arg("samples"), py::arg("callback") = py::none());
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,16 @@
// sherpa-onnx/python/csrc/offline-speaker-diarization.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
namespace sherpa_onnx {
void PybindOfflineSpeakerDiarization(py::module *m);
}
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_

View File

@@ -37,6 +37,8 @@
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
#include "sherpa-onnx/python/csrc/fast-clustering.h"
#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"
#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"
#endif
namespace sherpa_onnx {
@@ -74,14 +76,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
PybindOfflineTts(&m);
#endif
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
PybindFastClustering(&m);
#endif
PybindSpeakerEmbeddingExtractor(&m);
PybindSpeakerEmbeddingManager(&m);
PybindSpokenLanguageIdentification(&m);
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
PybindFastClustering(&m);
PybindOfflineSpeakerDiarizationResult(&m);
PybindOfflineSpeakerDiarization(&m);
#endif
PybindAlsa(&m);
}