Python API for speaker diarization. (#1400)
This commit is contained in:
15
.github/scripts/test-python.sh
vendored
15
.github/scripts/test-python.sh
vendored
@@ -8,6 +8,21 @@ log() {
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
log "test offline speaker diarization"
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
|
||||
|
||||
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
|
||||
|
||||
python3 ./python-api-examples/offline-speaker-diarization.py
|
||||
|
||||
rm -rf *.wav *.onnx ./sherpa-onnx-pyannote-segmentation-3-0
|
||||
|
||||
|
||||
log "test_clustering"
|
||||
pushd /tmp/
|
||||
mkdir test-cluster
|
||||
|
||||
2
.github/workflows/windows-x64.yaml
vendored
2
.github/workflows/windows-x64.yaml
vendored
@@ -93,7 +93,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
du -h -d1 .
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline-speaker-diarization.exe
|
||||
|
||||
.github/scripts/test-speaker-diarization.sh
|
||||
|
||||
2
.github/workflows/windows-x86.yaml
vendored
2
.github/workflows/windows-x86.yaml
vendored
@@ -93,7 +93,7 @@ jobs:
|
||||
shell: bash
|
||||
run: |
|
||||
du -h -d1 .
|
||||
export PATH=$PWD/build/bin:$PATH
|
||||
export PATH=$PWD/build/bin/Release:$PATH
|
||||
export EXE=sherpa-onnx-offline-speaker-diarization.exe
|
||||
|
||||
.github/scripts/test-speaker-diarization.sh
|
||||
|
||||
118
python-api-examples/offline-speaker-diarization.py
Executable file
118
python-api-examples/offline-speaker-diarization.py
Executable file
@@ -0,0 +1,118 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
"""
|
||||
This file shows how to use sherpa-onnx Python API for
|
||||
offline/non-streaming speaker diarization.
|
||||
|
||||
Usage:
|
||||
|
||||
Step 1: Download a speaker segmentation model
|
||||
|
||||
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
|
||||
for a list of available models. The following is an example
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
|
||||
Step 2: Download a speaker embedding extractor model
|
||||
|
||||
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
|
||||
for a list of available models. The following is an example
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
|
||||
|
||||
Step 3. Download test wave files
|
||||
|
||||
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
|
||||
for a list of available test wave files. The following is an example
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
|
||||
|
||||
Step 4. Run it
|
||||
|
||||
python3 ./python-api-examples/offline-speaker-diarization.py
|
||||
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def init_speaker_diarization(num_speakers: int = -1, cluster_threshold: float = 0.5):
|
||||
"""
|
||||
Args:
|
||||
num_speakers:
|
||||
If you know the actual number of speakers in the wave file, then please
|
||||
specify it. Otherwise, leave it to -1
|
||||
cluster_threshold:
|
||||
If num_speakers is -1, then this threshold is used for clustering.
|
||||
A smaller cluster_threshold leads to more clusters, i.e., more speakers.
|
||||
A larger cluster_threshold leads to fewer clusters, i.e., fewer speakers.
|
||||
"""
|
||||
segmentation_model = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx"
|
||||
embedding_extractor_model = (
|
||||
"./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
|
||||
)
|
||||
|
||||
config = sherpa_onnx.OfflineSpeakerDiarizationConfig(
|
||||
segmentation=sherpa_onnx.OfflineSpeakerSegmentationModelConfig(
|
||||
pyannote=sherpa_onnx.OfflineSpeakerSegmentationPyannoteModelConfig(
|
||||
model=segmentation_model
|
||||
),
|
||||
),
|
||||
embedding=sherpa_onnx.SpeakerEmbeddingExtractorConfig(
|
||||
model=embedding_extractor_model
|
||||
),
|
||||
clustering=sherpa_onnx.FastClusteringConfig(
|
||||
num_clusters=num_speakers, threshold=cluster_threshold
|
||||
),
|
||||
min_duration_on=0.3,
|
||||
min_duration_off=0.5,
|
||||
)
|
||||
if not config.validate():
|
||||
raise RuntimeError(
|
||||
"Please check your config and make sure all required files exist"
|
||||
)
|
||||
|
||||
return sherpa_onnx.OfflineSpeakerDiarization(config)
|
||||
|
||||
|
||||
def progress_callback(num_processed_chunk: int, num_total_chunks: int) -> int:
|
||||
progress = num_processed_chunk / num_total_chunks * 100
|
||||
print(f"Progress: {progress:.3f}%")
|
||||
return 0
|
||||
|
||||
|
||||
def main():
|
||||
wave_filename = "./0-four-speakers-zh.wav"
|
||||
if not Path(wave_filename).is_file():
|
||||
raise RuntimeError(f"{wave_filename} does not exist")
|
||||
|
||||
audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
|
||||
audio = audio[:, 0] # only use the first channel
|
||||
|
||||
# Since we know there are 4 speakers in the above test wave file, we use
|
||||
# num_speakers 4 here
|
||||
sd = init_speaker_diarization(num_speakers=4)
|
||||
if sample_rate != sd.sample_rate:
|
||||
raise RuntimeError(
|
||||
f"Expected samples rate: {sd.sample_rate}, given: {sample_rate}"
|
||||
)
|
||||
|
||||
show_porgress = True
|
||||
|
||||
if show_porgress:
|
||||
result = sd.process(audio, callback=progress_callback).sort_by_start_time()
|
||||
else:
|
||||
result = sd.process(audio).sort_by_start_time()
|
||||
|
||||
for r in result:
|
||||
print(f"{r.start:.3f} -- {r.end:.3f} speaker_{r.speaker:02}")
|
||||
# print(r) # this one is simpler
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -103,7 +103,7 @@ class OfflineSpeakerDiarizationPyannoteImpl
|
||||
auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
|
||||
Matrix2D embeddings =
|
||||
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
|
||||
callback, callback_arg);
|
||||
std::move(callback), callback_arg);
|
||||
|
||||
std::vector<int32_t> cluster_labels = clustering_.Cluster(
|
||||
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
|
||||
|
||||
@@ -28,6 +28,8 @@ class OfflineSpeakerDiarizationSegment {
|
||||
const std::string &Text() const { return text_; }
|
||||
float Duration() const { return end_ - start_; }
|
||||
|
||||
void SetText(const std::string &text) { text_ = text; }
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
private:
|
||||
|
||||
@@ -34,10 +34,13 @@ struct OfflineSpeakerDiarizationConfig {
|
||||
OfflineSpeakerDiarizationConfig(
|
||||
const OfflineSpeakerSegmentationModelConfig &segmentation,
|
||||
const SpeakerEmbeddingExtractorConfig &embedding,
|
||||
const FastClusteringConfig &clustering)
|
||||
const FastClusteringConfig &clustering, float min_duration_on,
|
||||
float min_duration_off)
|
||||
: segmentation(segmentation),
|
||||
embedding(embedding),
|
||||
clustering(clustering) {}
|
||||
clustering(clustering),
|
||||
min_duration_on(min_duration_on),
|
||||
min_duration_off(min_duration_off) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
@@ -62,6 +62,8 @@ endif()
|
||||
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
|
||||
list(APPEND srcs
|
||||
fast-clustering.cc
|
||||
offline-speaker-diarization-result.cc
|
||||
offline-speaker-diarization.cc
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
// sherpa-onnx/python/csrc/offline-speaker-diarization-result.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void PybindOfflineSpeakerDiarizationSegment(py::module *m) {
|
||||
using PyClass = OfflineSpeakerDiarizationSegment;
|
||||
py::class_<PyClass>(*m, "OfflineSpeakerDiarizationSegment")
|
||||
.def_property_readonly("start", &PyClass::Start)
|
||||
.def_property_readonly("end", &PyClass::End)
|
||||
.def_property_readonly("duration", &PyClass::Duration)
|
||||
.def_property_readonly("speaker", &PyClass::Speaker)
|
||||
.def_property("text", &PyClass::Text, &PyClass::SetText)
|
||||
.def("__str__", &PyClass::ToString);
|
||||
}
|
||||
|
||||
void PybindOfflineSpeakerDiarizationResult(py::module *m) {
|
||||
PybindOfflineSpeakerDiarizationSegment(m);
|
||||
using PyClass = OfflineSpeakerDiarizationResult;
|
||||
py::class_<PyClass>(*m, "OfflineSpeakerDiarizationResult")
|
||||
.def_property_readonly("num_speakers", &PyClass::NumSpeakers)
|
||||
.def_property_readonly("num_segments", &PyClass::NumSegments)
|
||||
.def("sort_by_start_time", &PyClass::SortByStartTime)
|
||||
.def("sort_by_speaker", &PyClass::SortBySpeaker);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
Normal file
16
sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-speaker-diarization-result.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineSpeakerDiarizationResult(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
|
||||
92
sherpa-onnx/python/csrc/offline-speaker-diarization.cc
Normal file
92
sherpa-onnx/python/csrc/offline-speaker-diarization.cc
Normal file
@@ -0,0 +1,92 @@
|
||||
// sherpa-onnx/python/csrc/offline-speaker-diarization.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
|
||||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static void PybindOfflineSpeakerSegmentationPyannoteModelConfig(py::module *m) {
|
||||
using PyClass = OfflineSpeakerSegmentationPyannoteModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineSpeakerSegmentationPyannoteModelConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const std::string &>(), py::arg("model"))
|
||||
.def_readwrite("model", &PyClass::model)
|
||||
.def("__str__", &PyClass::ToString)
|
||||
.def("validate", &PyClass::Validate);
|
||||
}
|
||||
|
||||
static void PybindOfflineSpeakerSegmentationModelConfig(py::module *m) {
|
||||
PybindOfflineSpeakerSegmentationPyannoteModelConfig(m);
|
||||
|
||||
using PyClass = OfflineSpeakerSegmentationModelConfig;
|
||||
py::class_<PyClass>(*m, "OfflineSpeakerSegmentationModelConfig")
|
||||
.def(py::init<>())
|
||||
.def(py::init<const OfflineSpeakerSegmentationPyannoteModelConfig &,
|
||||
int32_t, bool, const std::string &>(),
|
||||
py::arg("pyannote"), py::arg("num_threads") = 1,
|
||||
py::arg("debug") = false, py::arg("provider") = "cpu")
|
||||
.def_readwrite("pyannote", &PyClass::pyannote)
|
||||
.def_readwrite("num_threads", &PyClass::num_threads)
|
||||
.def_readwrite("debug", &PyClass::debug)
|
||||
.def_readwrite("provider", &PyClass::provider)
|
||||
.def("__str__", &PyClass::ToString)
|
||||
.def("validate", &PyClass::Validate);
|
||||
}
|
||||
|
||||
static void PybindOfflineSpeakerDiarizationConfig(py::module *m) {
|
||||
PybindOfflineSpeakerSegmentationModelConfig(m);
|
||||
|
||||
using PyClass = OfflineSpeakerDiarizationConfig;
|
||||
py::class_<PyClass>(*m, "OfflineSpeakerDiarizationConfig")
|
||||
.def(py::init<const OfflineSpeakerSegmentationModelConfig &,
|
||||
const SpeakerEmbeddingExtractorConfig &,
|
||||
const FastClusteringConfig &, float, float>(),
|
||||
py::arg("segmentation"), py::arg("embedding"), py::arg("clustering"),
|
||||
py::arg("min_duration_on") = 0.3, py::arg("min_duration_off") = 0.5)
|
||||
.def_readwrite("segmentation", &PyClass::segmentation)
|
||||
.def_readwrite("embedding", &PyClass::embedding)
|
||||
.def_readwrite("clustering", &PyClass::clustering)
|
||||
.def_readwrite("min_duration_on", &PyClass::min_duration_on)
|
||||
.def_readwrite("min_duration_off", &PyClass::min_duration_off)
|
||||
.def("__str__", &PyClass::ToString)
|
||||
.def("validate", &PyClass::Validate);
|
||||
}
|
||||
|
||||
void PybindOfflineSpeakerDiarization(py::module *m) {
|
||||
PybindOfflineSpeakerDiarizationConfig(m);
|
||||
|
||||
using PyClass = OfflineSpeakerDiarization;
|
||||
py::class_<PyClass>(*m, "OfflineSpeakerDiarization")
|
||||
.def(py::init<const OfflineSpeakerDiarizationConfig &>(),
|
||||
py::arg("config"))
|
||||
.def_property_readonly("sample_rate", &PyClass::SampleRate)
|
||||
.def(
|
||||
"process",
|
||||
[](const PyClass &self, const std::vector<float> samples,
|
||||
std::function<int32_t(int32_t, int32_t)> callback) {
|
||||
if (!callback) {
|
||||
return self.Process(samples.data(), samples.size());
|
||||
}
|
||||
|
||||
std::function<int32_t(int32_t, int32_t, void *)> callback_wrapper =
|
||||
[callback](int32_t processed_chunks, int32_t num_chunks,
|
||||
void *) -> int32_t {
|
||||
callback(processed_chunks, num_chunks);
|
||||
return 0;
|
||||
};
|
||||
|
||||
return self.Process(samples.data(), samples.size(),
|
||||
callback_wrapper);
|
||||
},
|
||||
py::arg("samples"), py::arg("callback") = py::none());
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
16
sherpa-onnx/python/csrc/offline-speaker-diarization.h
Normal file
16
sherpa-onnx/python/csrc/offline-speaker-diarization.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// sherpa-onnx/python/csrc/offline-speaker-diarization.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
|
||||
#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
|
||||
|
||||
#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void PybindOfflineSpeakerDiarization(py::module *m);
|
||||
|
||||
}
|
||||
|
||||
#endif // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
|
||||
@@ -37,6 +37,8 @@
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
|
||||
#include "sherpa-onnx/python/csrc/fast-clustering.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-speaker-diarization-result.h"
|
||||
#include "sherpa-onnx/python/csrc/offline-speaker-diarization.h"
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -74,14 +76,16 @@ PYBIND11_MODULE(_sherpa_onnx, m) {
|
||||
PybindOfflineTts(&m);
|
||||
#endif
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
|
||||
PybindFastClustering(&m);
|
||||
#endif
|
||||
|
||||
PybindSpeakerEmbeddingExtractor(&m);
|
||||
PybindSpeakerEmbeddingManager(&m);
|
||||
PybindSpokenLanguageIdentification(&m);
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION == 1
|
||||
PybindFastClustering(&m);
|
||||
PybindOfflineSpeakerDiarizationResult(&m);
|
||||
PybindOfflineSpeakerDiarization(&m);
|
||||
#endif
|
||||
|
||||
PybindAlsa(&m);
|
||||
}
|
||||
|
||||
|
||||
@@ -11,6 +11,12 @@ from _sherpa_onnx import (
|
||||
OfflinePunctuation,
|
||||
OfflinePunctuationConfig,
|
||||
OfflinePunctuationModelConfig,
|
||||
OfflineSpeakerDiarization,
|
||||
OfflineSpeakerDiarizationConfig,
|
||||
OfflineSpeakerDiarizationResult,
|
||||
OfflineSpeakerDiarizationSegment,
|
||||
OfflineSpeakerSegmentationModelConfig,
|
||||
OfflineSpeakerSegmentationPyannoteModelConfig,
|
||||
OfflineStream,
|
||||
OfflineTts,
|
||||
OfflineTtsConfig,
|
||||
|
||||
Reference in New Issue
Block a user