This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex-mr_series-sherpa-onnx/python-api-examples/offline-speaker-diarization.py
2024-10-28 21:30:18 +08:00

119 lines
4.0 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright (c) 2024 Xiaomi Corporation
"""
This file shows how to use sherpa-onnx Python API for
offline/non-streaming speaker diarization.
Usage:
Step 1: Download a speaker segmentation model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
Step 2: Download a speaker embedding extractor model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
Step 3. Download test wave files
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
Step 4. Run it
python3 ./python-api-examples/offline-speaker-diarization.py
"""
from pathlib import Path
import sherpa_onnx
import soundfile as sf
def init_speaker_diarization(
    num_speakers: int = -1,
    cluster_threshold: float = 0.5,
    segmentation_model: str = "./sherpa-onnx-pyannote-segmentation-3-0/model.onnx",
    embedding_extractor_model: str = (
        "./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx"
    ),
) -> "sherpa_onnx.OfflineSpeakerDiarization":
    """Create a sherpa-onnx offline speaker diarization object.

    Args:
      num_speakers:
        If you know the actual number of speakers in the wave file, then please
        specify it. Otherwise, leave it to -1
      cluster_threshold:
        If num_speakers is -1, then this threshold is used for clustering.
        A smaller cluster_threshold leads to more clusters, i.e., more speakers.
        A larger cluster_threshold leads to fewer clusters, i.e., fewer speakers.
      segmentation_model:
        Path to the pyannote speaker segmentation ONNX model (Step 1 of the
        module docstring). Defaults to the path used by this example.
      embedding_extractor_model:
        Path to the speaker embedding extractor ONNX model (Step 2 of the
        module docstring). Defaults to the path used by this example.

    Returns:
      A ``sherpa_onnx.OfflineSpeakerDiarization`` built from the config.

    Raises:
      RuntimeError: if ``config.validate()`` rejects the config, e.g. because
        a required model file does not exist.
    """
    config = sherpa_onnx.OfflineSpeakerDiarizationConfig(
        segmentation=sherpa_onnx.OfflineSpeakerSegmentationModelConfig(
            pyannote=sherpa_onnx.OfflineSpeakerSegmentationPyannoteModelConfig(
                model=segmentation_model
            ),
        ),
        embedding=sherpa_onnx.SpeakerEmbeddingExtractorConfig(
            model=embedding_extractor_model
        ),
        # num_clusters=-1 means "unknown"; threshold is then used instead.
        clustering=sherpa_onnx.FastClusteringConfig(
            num_clusters=num_speakers, threshold=cluster_threshold
        ),
        # NOTE(review): presumably minimum segment/gap durations in seconds
        # used for post-processing — confirm against the sherpa-onnx docs.
        min_duration_on=0.3,
        min_duration_off=0.5,
    )
    # Fail early with a clear message instead of letting the constructor fail.
    if not config.validate():
        raise RuntimeError(
            "Please check your config and make sure all required files exist"
        )
    return sherpa_onnx.OfflineSpeakerDiarization(config)
def progress_callback(num_processed_chunk: int, num_total_chunks: int) -> int:
    """Print diarization progress as a percentage.

    Args:
      num_processed_chunk:
        Number of chunks processed so far.
      num_total_chunks:
        Total number of chunks. A value <= 0 (nothing to process) is
        reported as 100% instead of raising ZeroDivisionError.

    Returns:
      Always 0. NOTE(review): the meaning of the return value to
      sherpa-onnx is not visible here — confirm whether it is ignored or
      used as a stop signal before returning anything else.
    """
    if num_total_chunks > 0:
        progress = num_processed_chunk / num_total_chunks * 100
    else:
        progress = 100.0
    print(f"Progress: {progress:.3f}%")
    return 0
def main(
    *,
    wave_filename: str = "./0-four-speakers-zh.wav",
    num_speakers: int = 4,
    show_progress: bool = True,
):
    """Run speaker diarization on a wave file and print each segment.

    All arguments are keyword-only and default to the values this example
    used before, so ``main()`` with no arguments behaves as it always did.

    Args:
      wave_filename:
        Path to the input wave file. Only its first channel is used.
      num_speakers:
        Number of speakers in the file, or -1 to let clustering decide
        (see ``init_speaker_diarization``). The default 4 matches the
        example file 0-four-speakers-zh.wav.
      show_progress:
        If True, report progress through ``progress_callback``.

    Raises:
      RuntimeError: if the wave file does not exist, or if its sample rate
        differs from the one expected by the diarization object.
    """
    if not Path(wave_filename).is_file():
        raise RuntimeError(f"{wave_filename} does not exist")

    # always_2d=True gives a 2-D array even for mono files, so the
    # first-channel slice below works for both mono and multi-channel input.
    audio, sample_rate = sf.read(wave_filename, dtype="float32", always_2d=True)
    audio = audio[:, 0]  # only use the first channel

    sd = init_speaker_diarization(num_speakers=num_speakers)
    if sample_rate != sd.sample_rate:
        raise RuntimeError(
            f"Expected sample rate: {sd.sample_rate}, given: {sample_rate}"
        )

    if show_progress:
        result = sd.process(audio, callback=progress_callback).sort_by_start_time()
    else:
        result = sd.process(audio).sort_by_start_time()

    for r in result:
        print(f"{r.start:.3f} -- {r.end:.3f} speaker_{r.speaker:02}")
        # print(r) would also work and is simpler


if __name__ == "__main__":
    main()