88 lines
2.3 KiB
Python
Executable File
88 lines
2.3 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
|
|
|
"""
|
|
Please refer to
|
|
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
|
|
for usages.
|
|
"""
|
|
|
|
"""
|
|
1. Go to https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main
|
|
wget https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/resolve/main/speaker-embedding.onnx
|
|
|
|
2. Change line 166 of pyannote/audio/pipelines/speaker_diarization.py
|
|
|
|
```
|
|
# self._embedding = PretrainedSpeakerEmbedding(
|
|
# self.embedding, use_auth_token=use_auth_token
|
|
# )
|
|
self._embedding = embedding
|
|
```
|
|
"""
|
|
|
|
import argparse
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
from pyannote.audio import Model
|
|
from pyannote.audio.pipelines import SpeakerDiarization as SpeakerDiarizationPipeline
|
|
from pyannote.audio.pipelines.speaker_verification import (
|
|
ONNXWeSpeakerPretrainedSpeakerEmbedding,
|
|
)
|
|
|
|
|
|
def get_args():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
|
|
|
|
return parser.parse_args()
|
|
|
|
|
|
def build_pipeline():
|
|
embedding_filename = "./speaker-embedding.onnx"
|
|
if Path(embedding_filename).is_file():
|
|
# You need to modify line 166
|
|
# of pyannote/audio/pipelines/speaker_diarization.py
|
|
# Please see the comments at the start of this script for details
|
|
embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(embedding_filename)
|
|
else:
|
|
embedding = "hbredin/wespeaker-voxceleb-resnet34-LM"
|
|
|
|
pt_filename = "./pytorch_model.bin"
|
|
segmentation = Model.from_pretrained(pt_filename)
|
|
segmentation.eval()
|
|
|
|
pipeline = SpeakerDiarizationPipeline(
|
|
segmentation=segmentation,
|
|
embedding=embedding,
|
|
embedding_exclude_overlap=True,
|
|
)
|
|
|
|
params = {
|
|
"clustering": {
|
|
"method": "centroid",
|
|
"min_cluster_size": 12,
|
|
"threshold": 0.7045654963945799,
|
|
},
|
|
"segmentation": {"min_duration_off": 0.5},
|
|
}
|
|
|
|
pipeline.instantiate(params)
|
|
return pipeline
|
|
|
|
|
|
@torch.no_grad()
|
|
def main():
|
|
args = get_args()
|
|
assert Path(args.wav).is_file(), args.wav
|
|
pipeline = build_pipeline()
|
|
print(pipeline)
|
|
t = pipeline(args.wav)
|
|
print(type(t))
|
|
print(t)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|