Python API for speaker diarization. (#1400)
This commit is contained in:
@@ -62,6 +62,8 @@ endif()
|
||||
# Append the speaker-diarization sources to `srcs` only when
# SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION is turned on.
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
|
||||
list(APPEND srcs
|
||||
fast-clustering.cc
|
||||
offline-speaker-diarization-result.cc
|
||||
offline-speaker-diarization.cc
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
// sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Registers the Python class `OfflineSpeakerDiarizationSegment` on module `m`.
//
// Python-visible members:
//   start, end, duration, speaker  -- read-only properties
//   text                           -- read/write property
//   __str__                        -- forwarded to ToString()
static void PybindOfflineSpeakerDiarizationSegment(py::module *m) {
  using Segment = OfflineSpeakerDiarizationSegment;

  py::class_<Segment> segment(*m, "OfflineSpeakerDiarizationSegment");

  // Read-only accessors.
  segment.def_property_readonly("start", &Segment::Start);
  segment.def_property_readonly("end", &Segment::End);
  segment.def_property_readonly("duration", &Segment::Duration);
  segment.def_property_readonly("speaker", &Segment::Speaker);

  // `text` is the only attribute that can also be assigned from Python.
  segment.def_property("text", &Segment::Text, &Segment::SetText);

  segment.def("__str__", &Segment::ToString);
}
|
||||
|
||||
// Registers the Python classes `OfflineSpeakerDiarizationSegment` and
// `OfflineSpeakerDiarizationResult` on module `m`.
//
// Python-visible members of `OfflineSpeakerDiarizationResult`:
//   num_speakers, num_segments           -- read-only properties
//   sort_by_start_time, sort_by_speaker  -- methods
void PybindOfflineSpeakerDiarizationResult(py::module *m) {
  // The segment class is registered before the result class.
  PybindOfflineSpeakerDiarizationSegment(m);

  using Result = OfflineSpeakerDiarizationResult;

  py::class_<Result> result(*m, "OfflineSpeakerDiarizationResult");

  result.def_property_readonly("num_speakers", &Result::NumSpeakers);
  result.def_property_readonly("num_segments", &Result::NumSegments);

  result.def("sort_by_start_time", &Result::SortByStartTime);
  result.def("sort_by_speaker", &Result::SortBySpeaker);
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
Normal file
16
sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Registers the Python classes `OfflineSpeakerDiarizationSegment` and
// `OfflineSpeakerDiarizationResult` on the module `m`.
void PybindOfflineSpeakerDiarizationResult(py::module *m);
|
||||
|
||||
}  // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
|
||||
92
sherpa-onnx/python/csrc/offline-speaker-diarization.cc
Normal file
92
sherpa-onnx/python/csrc/offline-speaker-diarization.cc
Normal file
@@ -0,0 +1,92 @@
|
||||
// sherpa-onnx/python/csrc/offline-speaker-diarization.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
|
||||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Registers the Python class `OfflineSpeakerSegmentationPyannoteModelConfig`
// on module `m`.
//
// Python-visible members:
//   __init__()       -- default construction
//   __init__(model)  -- construction from a string
//   model            -- read/write attribute
//   __str__          -- forwarded to ToString()
//   validate         -- forwarded to Validate()
static void PybindOfflineSpeakerSegmentationPyannoteModelConfig(py::module *m) {
  using Config = OfflineSpeakerSegmentationPyannoteModelConfig;

  py::class_<Config> config(*m,
                            "OfflineSpeakerSegmentationPyannoteModelConfig");

  // Constructors are registered in the same order as before: default first.
  config.def(py::init<>());
  config.def(py::init<const std::string &>(), py::arg("model"));

  config.def_readwrite("model", &Config::model);

  config.def("__str__", &Config::ToString);
  config.def("validate", &Config::Validate);
}
|
||||
|
||||
// Registers the Python class `OfflineSpeakerSegmentationModelConfig` on module
// `m`, after registering the nested pyannote config class it refers to.
//
// Python-visible members:
//   __init__()  -- default construction
//   __init__(pyannote, num_threads=1, debug=False, provider="cpu")
//   pyannote, num_threads, debug, provider  -- read/write attributes
//   __str__     -- forwarded to ToString()
//   validate    -- forwarded to Validate()
static void PybindOfflineSpeakerSegmentationModelConfig(py::module *m) {
  PybindOfflineSpeakerSegmentationPyannoteModelConfig(m);

  using Config = OfflineSpeakerSegmentationModelConfig;

  py::class_<Config> config(*m, "OfflineSpeakerSegmentationModelConfig");

  // Constructors are registered in the same order as before: default first.
  config.def(py::init<>());
  config.def(py::init<const OfflineSpeakerSegmentationPyannoteModelConfig &,
                      int32_t, bool, const std::string &>(),
             py::arg("pyannote"), py::arg("num_threads") = 1,
             py::arg("debug") = false, py::arg("provider") = "cpu");

  config.def_readwrite("pyannote", &Config::pyannote);
  config.def_readwrite("num_threads", &Config::num_threads);
  config.def_readwrite("debug", &Config::debug);
  config.def_readwrite("provider", &Config::provider);

  config.def("__str__", &Config::ToString);
  config.def("validate", &Config::Validate);
}
|
||||
|
||||
// Registers the Python class `OfflineSpeakerDiarizationConfig` on module `m`,
// after registering the segmentation model config classes.
//
// Python-visible members:
//   __init__(segmentation, embedding, clustering,
//            min_duration_on=0.3, min_duration_off=0.5)
//   segmentation, embedding, clustering,
//   min_duration_on, min_duration_off  -- read/write attributes
//   __str__   -- forwarded to ToString()
//   validate  -- forwarded to Validate()
static void PybindOfflineSpeakerDiarizationConfig(py::module *m) {
  PybindOfflineSpeakerSegmentationModelConfig(m);

  using Config = OfflineSpeakerDiarizationConfig;

  py::class_<Config> config(*m, "OfflineSpeakerDiarizationConfig");

  config.def(py::init<const OfflineSpeakerSegmentationModelConfig &,
                      const SpeakerEmbeddingExtractorConfig &,
                      const FastClusteringConfig &, float, float>(),
             py::arg("segmentation"), py::arg("embedding"),
             py::arg("clustering"), py::arg("min_duration_on") = 0.3,
             py::arg("min_duration_off") = 0.5);

  config.def_readwrite("segmentation", &Config::segmentation);
  config.def_readwrite("embedding", &Config::embedding);
  config.def_readwrite("clustering", &Config::clustering);
  config.def_readwrite("min_duration_on", &Config::min_duration_on);
  config.def_readwrite("min_duration_off", &Config::min_duration_off);

  config.def("__str__", &Config::ToString);
  config.def("validate", &Config::Validate);
}
|
||||
|
||||
// Registers the Python class `OfflineSpeakerDiarization` on module `m`, after
// registering its config classes.
//
// Python-visible members:
//   __init__(config)
//   sample_rate                      -- read-only property
//   process(samples, callback=None)  -- forwards the samples to Process()
//
// When `callback` is given, it is invoked as
// `callback(processed_chunks, num_chunks)`.
void PybindOfflineSpeakerDiarization(py::module *m) {
  PybindOfflineSpeakerDiarizationConfig(m);

  using Diarizer = OfflineSpeakerDiarization;

  py::class_<Diarizer> diarizer(*m, "OfflineSpeakerDiarization");

  diarizer.def(py::init<const OfflineSpeakerDiarizationConfig &>(),
               py::arg("config"));

  diarizer.def_property_readonly("sample_rate", &Diarizer::SampleRate);

  diarizer.def(
      "process",
      [](const Diarizer &self, const std::vector<float> samples,
         std::function<int32_t(int32_t, int32_t)> callback) {
        if (callback) {
          // Process() takes a callback with an extra `void *` argument, so
          // the two-argument Python callback is adapted here. The `void *`
          // is ignored, the value returned by the Python callback is
          // discarded, and the adapter always returns 0.
          std::function<int32_t(int32_t, int32_t, void *)> adapter =
              [callback](int32_t processed_chunks, int32_t num_chunks,
                         void * /*unused*/) -> int32_t {
            callback(processed_chunks, num_chunks);
            return 0;
          };

          return self.Process(samples.data(), samples.size(), adapter);
        }

        // No callback was supplied (the Python default is None).
        return self.Process(samples.data(), samples.size());
      },
      py::arg("samples"), py::arg("callback") = py::none());
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/offline-speaker-diarization.h
Normal file
16
sherpa-onnx/python/csrc/offline-speaker-diarization.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-speaker-diarization.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Registers the Python class `OfflineSpeakerDiarization`, together with its
// config classes (`OfflineSpeakerDiarizationConfig`,
// `OfflineSpeakerSegmentationModelConfig` and
// `OfflineSpeakerSegmentationPyannoteModelConfig`), on the module `m`.
void PybindOfflineSpeakerDiarization(py::module *m);
|
||||
|
||||
}  // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
|
||||
@@ -37,6 +37,8 @@
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
|
||||
#include "sherpa-onnx/python/csrc/fast-clustering.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -74,14 +76,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
PybindOfflineTts(&m);
|
||||
#endif
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
|
||||
PybindFastClustering(&m);
|
||||
#endif
|
||||
|
||||
PybindSpeakerEmbeddingExtractor(&m);
|
||||
PybindSpeakerEmbeddingManager(&m);
|
||||
PybindSpokenLanguageIdentification(&m);
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
|
||||
PybindFastClustering(&m);
|
||||
PybindOfflineSpeakerDiarizationResult(&m);
|
||||
PybindOfflineSpeakerDiarization(&m);
|
||||
#endif
|
||||
|
||||
PybindAlsa(&m);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user