Speaker diarization example with onnxruntime Python API (#1395)
This commit is contained in:
86
scripts/pyannote/segmentation/speaker-diarization-torch.py
Executable file
86
scripts/pyannote/segmentation/speaker-diarization-torch.py
Executable file
@@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
"""
|
||||
Please refer to
|
||||
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
|
||||
for usages.
|
||||
"""
|
||||
|
||||
"""
|
||||
1. Go to https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main
|
||||
wget https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/resolve/main/speaker-embedding.onnx
|
||||
|
||||
2. Change line 166 of pyannote/audio/pipelines/speaker_diarization.py
|
||||
|
||||
```
|
||||
# self._embedding = PretrainedSpeakerEmbedding(
|
||||
# self.embedding, use_auth_token=use_auth_token
|
||||
# )
|
||||
self._embedding = embedding
|
||||
```
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from pyannote.audio import Model
|
||||
from pyannote.audio.pipelines import SpeakerDiarization as SpeakerDiarizationPipeline
|
||||
from pyannote.audio.pipelines.speaker_verification import (
|
||||
ONNXWeSpeakerPretrainedSpeakerEmbedding,
|
||||
)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_pipeline():
|
||||
embedding_filename = "./speaker-embedding.onnx"
|
||||
if Path(embedding_filename).is_file():
|
||||
# You need to modify line 166
|
||||
# of pyannote/audio/pipelines/speaker_diarization.py
|
||||
# Please see the comments at the start of this script for details
|
||||
embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(embedding_filename)
|
||||
else:
|
||||
embedding = "hbredin/wespeaker-voxceleb-resnet34-LM"
|
||||
|
||||
pt_filename = "./pytorch_model.bin"
|
||||
segmentation = Model.from_pretrained(pt_filename)
|
||||
segmentation.eval()
|
||||
|
||||
pipeline = SpeakerDiarizationPipeline(
|
||||
segmentation=segmentation,
|
||||
embedding=embedding,
|
||||
embedding_exclude_overlap=True,
|
||||
)
|
||||
|
||||
params = {
|
||||
"clustering": {
|
||||
"method": "centroid",
|
||||
"min_cluster_size": 12,
|
||||
"threshold": 0.7045654963945799,
|
||||
},
|
||||
"segmentation": {"min_duration_off": 0.5},
|
||||
}
|
||||
|
||||
pipeline.instantiate(params)
|
||||
return pipeline
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_args()
|
||||
assert Path(args.wav).is_file(), args.wav
|
||||
pipeline = build_pipeline()
|
||||
print(pipeline)
|
||||
t = pipeline(args.wav)
|
||||
print(type(t))
|
||||
print(t)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user