Speaker diarization example with onnxruntime Python API (#1395)
This commit is contained in:
98
.github/workflows/speaker-diarization.yaml
vendored
Normal file
98
.github/workflows/speaker-diarization.yaml
vendored
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
name: speaker-diarization
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- speaker-diarization
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
concurrency:
|
||||||
|
group: speaker-diarization-${{ github.ref }}
|
||||||
|
cancel-in-progress: true
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
linux:
|
||||||
|
name: speaker diarization
|
||||||
|
runs-on: ${{ matrix.os }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
os: [macos-latest]
|
||||||
|
python-version: ["3.10"]
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: ccache
|
||||||
|
uses: hendrikmuhs/ccache-action@v1.2
|
||||||
|
with:
|
||||||
|
key: ${{ matrix.os }}-speaker-diarization
|
||||||
|
|
||||||
|
- name: Setup Python ${{ matrix.python-version }}
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: ${{ matrix.python-version }}
|
||||||
|
|
||||||
|
- name: Install pyannote
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
pip install pyannote.audio onnx onnxruntime
|
||||||
|
|
||||||
|
- name: Install sherpa-onnx from source
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
python3 -m pip install --upgrade pip
|
||||||
|
python3 -m pip install wheel twine setuptools
|
||||||
|
|
||||||
|
export CMAKE_CXX_COMPILER_LAUNCHER=ccache
|
||||||
|
export PATH="/usr/lib/ccache:/usr/local/opt/ccache/libexec:$PATH"
|
||||||
|
|
||||||
|
cat sherpa-onnx/python/sherpa_onnx/__init__.py
|
||||||
|
|
||||||
|
python3 setup.py bdist_wheel
|
||||||
|
ls -lh dist
|
||||||
|
pip install ./dist/*.whl
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
shell: bash
|
||||||
|
run: |
|
||||||
|
pushd scripts/pyannote/segmentation
|
||||||
|
|
||||||
|
python3 -c "import sherpa_onnx; print(sherpa_onnx.__file__)"
|
||||||
|
python3 -c "import sherpa_onnx; print(sherpa_onnx.__version__)"
|
||||||
|
python3 -c "import sherpa_onnx; print(dir(sherpa_onnx))"
|
||||||
|
|
||||||
|
curl -SL -O https://huggingface.co/csukuangfj/pyannote-models/resolve/main/segmentation-3.0/pytorch_model.bin
|
||||||
|
|
||||||
|
test_wavs=(
|
||||||
|
0-two-speakers-zh.wav
|
||||||
|
1-two-speakers-en.wav
|
||||||
|
2-two-speakers-en.wav
|
||||||
|
3-two-speakers-en.wav
|
||||||
|
)
|
||||||
|
|
||||||
|
for w in ${test_wavs[@]}; do
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/$w
|
||||||
|
done
|
||||||
|
|
||||||
|
soxi *.wav
|
||||||
|
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||||
|
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||||
|
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||||
|
ls -lh sherpa-onnx-pyannote-segmentation-3-0
|
||||||
|
|
||||||
|
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
|
||||||
|
|
||||||
|
for w in ${test_wavs[@]}; do
|
||||||
|
echo "---------test $w (onnx)----------"
|
||||||
|
time ./speaker-diarization-onnx.py \
|
||||||
|
--seg-model ./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
|
||||||
|
--speaker-embedding-model ./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
|
||||||
|
--wav $w
|
||||||
|
|
||||||
|
echo "---------test $w (torch)----------"
|
||||||
|
time ./speaker-diarization-torch.py --wav $w
|
||||||
|
done
|
||||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -118,3 +118,5 @@ vits-melo-tts-zh_en
|
|||||||
*.o
|
*.o
|
||||||
*.ppu
|
*.ppu
|
||||||
sherpa-onnx-online-punct-en-2024-08-06
|
sherpa-onnx-online-punct-en-2024-08-06
|
||||||
|
*.mp4
|
||||||
|
*.mp3
|
||||||
|
|||||||
44
scripts/pyannote/segmentation/README.md
Normal file
44
scripts/pyannote/segmentation/README.md
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# File description
|
||||||
|
|
||||||
|
Please download test wave files from
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
|
||||||
|
|
||||||
|
## 0-two-speakers-zh.wav
|
||||||
|
|
||||||
|
This file is from
|
||||||
|
https://www.modelscope.cn/models/iic/speech_campplus_speaker-diarization_common/file/view/master?fileName=examples%252F2speakers_example.wav&status=0
|
||||||
|
|
||||||
|
Note that we have renamed it from `2speakers_example.wav` to `0-two-speakers-zh.wav`.
|
||||||
|
|
||||||
|
## 1-two-speakers-en.wav
|
||||||
|
|
||||||
|
This file is from
|
||||||
|
https://github.com/pengzhendong/pyannote-onnx/blob/master/data/test_16k.wav
|
||||||
|
and it contains speeches from two speakers.
|
||||||
|
|
||||||
|
Note that we have renamed it from `test_16k.wav` to `1-two-speakers-en.wav`
|
||||||
|
|
||||||
|
|
||||||
|
## 2-two-speakers-en.wav
|
||||||
|
This file is from
|
||||||
|
https://huggingface.co/spaces/Xenova/whisper-speaker-diarization
|
||||||
|
|
||||||
|
Note that the original file is `./fcf059e3-689f-47ec-a000-bdace87f0113.mp4`.
|
||||||
|
We use the following commands to convert it to `2-two-speakers-en.wav`.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ffmpeg -i ./fcf059e3-689f-47ec-a000-bdace87f0113.mp4 -ac 1 -ar 16000 ./2-two-speakers-en.wav
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3-two-speakers-en.wav
|
||||||
|
|
||||||
|
This file is from
|
||||||
|
https://aws.amazon.com/blogs/machine-learning/deploy-a-hugging-face-pyannote-speaker-diarization-model-on-amazon-sagemaker-as-an-asynchronous-endpoint/
|
||||||
|
|
||||||
|
Note that the original file is `ML16091-Audio.mp3`. We use the following
|
||||||
|
commands to convert it to `3-two-speakers-en.wav`
|
||||||
|
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sox ML16091-Audio.mp3 3-two-speakers-en.wav
|
||||||
|
```
|
||||||
488
scripts/pyannote/segmentation/speaker-diarization-onnx.py
Executable file
488
scripts/pyannote/segmentation/speaker-diarization-onnx.py
Executable file
@@ -0,0 +1,488 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
|
||||||
|
"""
|
||||||
|
Please refer to
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
|
||||||
|
for usages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from datetime import timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import librosa
|
||||||
|
import numpy as np
|
||||||
|
import onnxruntime as ort
|
||||||
|
import sherpa_onnx
|
||||||
|
import soundfile as sf
|
||||||
|
from numpy.lib.stride_tricks import as_strided
|
||||||
|
|
||||||
|
|
||||||
|
class Segment:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
speaker,
|
||||||
|
):
|
||||||
|
assert start < end
|
||||||
|
self.start = start
|
||||||
|
self.end = end
|
||||||
|
self.speaker = speaker
|
||||||
|
|
||||||
|
def merge(self, other, gap=0.5):
|
||||||
|
assert self.speaker == other.speaker, (self.speaker, other.speaker)
|
||||||
|
if self.end < other.start and self.end + gap >= other.start:
|
||||||
|
return Segment(start=self.start, end=other.end, speaker=self.speaker)
|
||||||
|
elif other.end < self.start and other.end + gap >= self.start:
|
||||||
|
return Segment(start=other.start, end=self.end, speaker=self.speaker)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def duration(self):
|
||||||
|
return self.end - self.start
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
s = f"{timedelta(seconds=self.start)}"[:-3]
|
||||||
|
s += " --> "
|
||||||
|
s += f"{timedelta(seconds=self.end)}"[:-3]
|
||||||
|
s += f" speaker_{self.speaker:02d}"
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def merge_segment_list(in_out: List[Segment], min_duration_off: float):
|
||||||
|
changed = True
|
||||||
|
while changed:
|
||||||
|
changed = False
|
||||||
|
for i in range(len(in_out)):
|
||||||
|
if i + 1 >= len(in_out):
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_segment = in_out[i].merge(in_out[i + 1], gap=min_duration_off)
|
||||||
|
if new_segment is None:
|
||||||
|
continue
|
||||||
|
del in_out[i + 1]
|
||||||
|
in_out[i] = new_segment
|
||||||
|
changed = True
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--seg-model",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to model.onnx for segmentation",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--speaker-embedding-model",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to model.onnx for speaker embedding extractor",
|
||||||
|
)
|
||||||
|
parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
class OnnxSegmentationModel:
|
||||||
|
def __init__(self, filename):
|
||||||
|
session_opts = ort.SessionOptions()
|
||||||
|
session_opts.inter_op_num_threads = 1
|
||||||
|
session_opts.intra_op_num_threads = 1
|
||||||
|
|
||||||
|
self.session_opts = session_opts
|
||||||
|
|
||||||
|
self.model = ort.InferenceSession(
|
||||||
|
filename,
|
||||||
|
sess_options=self.session_opts,
|
||||||
|
providers=["CPUExecutionProvider"],
|
||||||
|
)
|
||||||
|
|
||||||
|
meta = self.model.get_modelmeta().custom_metadata_map
|
||||||
|
print(meta)
|
||||||
|
|
||||||
|
self.window_size = int(meta["window_size"])
|
||||||
|
self.sample_rate = int(meta["sample_rate"])
|
||||||
|
self.window_shift = int(0.1 * self.window_size)
|
||||||
|
self.receptive_field_size = int(meta["receptive_field_size"])
|
||||||
|
self.receptive_field_shift = int(meta["receptive_field_shift"])
|
||||||
|
self.num_speakers = int(meta["num_speakers"])
|
||||||
|
self.powerset_max_classes = int(meta["powerset_max_classes"])
|
||||||
|
self.num_classes = int(meta["num_classes"])
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x: (N, num_samples)
|
||||||
|
Returns:
|
||||||
|
A tensor of shape (N, num_frames, num_classes)
|
||||||
|
"""
|
||||||
|
x = np.expand_dims(x, axis=1)
|
||||||
|
|
||||||
|
(y,) = self.model.run(
|
||||||
|
[self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: x}
|
||||||
|
)
|
||||||
|
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
def load_wav(filename, expected_sample_rate) -> np.ndarray:
|
||||||
|
audio, sample_rate = sf.read(filename, dtype="float32", always_2d=True)
|
||||||
|
audio = audio[:, 0] # only use the first channel
|
||||||
|
if sample_rate != expected_sample_rate:
|
||||||
|
audio = librosa.resample(
|
||||||
|
audio,
|
||||||
|
orig_sr=sample_rate,
|
||||||
|
target_sr=expected_sample_rate,
|
||||||
|
)
|
||||||
|
return audio
|
||||||
|
|
||||||
|
|
||||||
|
def get_powerset_mapping(num_classes, num_speakers, powerset_max_classes):
|
||||||
|
mapping = np.zeros((num_classes, num_speakers))
|
||||||
|
|
||||||
|
k = 1
|
||||||
|
for i in range(1, powerset_max_classes + 1):
|
||||||
|
if i == 1:
|
||||||
|
for j in range(0, num_speakers):
|
||||||
|
mapping[k, j] = 1
|
||||||
|
k += 1
|
||||||
|
elif i == 2:
|
||||||
|
for j in range(0, num_speakers):
|
||||||
|
for m in range(j + 1, num_speakers):
|
||||||
|
mapping[k, j] = 1
|
||||||
|
mapping[k, m] = 1
|
||||||
|
k += 1
|
||||||
|
elif i == 3:
|
||||||
|
raise RuntimeError("Unsupported")
|
||||||
|
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
|
||||||
|
def to_multi_label(y, mapping):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
y: (num_chunks, num_frames, num_classes)
|
||||||
|
Returns:
|
||||||
|
A tensor of shape (num_chunks, num_frames, num_speakers)
|
||||||
|
"""
|
||||||
|
y = np.argmax(y, axis=-1)
|
||||||
|
labels = mapping[y.reshape(-1)].reshape(y.shape[0], y.shape[1], -1)
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
# speaker count per frame
|
||||||
|
def speaker_count(labels, seg_m):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
labels: (num_chunks, num_frames, num_speakers)
|
||||||
|
seg_m: Segmentation model
|
||||||
|
Returns:
|
||||||
|
A integer array of shape (num_total_frames,)
|
||||||
|
"""
|
||||||
|
labels = labels.sum(axis=-1)
|
||||||
|
# Now labels: (num_chunks, num_frames)
|
||||||
|
|
||||||
|
num_frames = (
|
||||||
|
int(
|
||||||
|
(seg_m.window_size + (labels.shape[0] - 1) * seg_m.window_shift)
|
||||||
|
/ seg_m.receptive_field_shift
|
||||||
|
)
|
||||||
|
+ 1
|
||||||
|
)
|
||||||
|
ans = np.zeros((num_frames,))
|
||||||
|
count = np.zeros((num_frames,))
|
||||||
|
|
||||||
|
for i in range(labels.shape[0]):
|
||||||
|
this_chunk = labels[i]
|
||||||
|
start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5)
|
||||||
|
end = start + this_chunk.shape[0]
|
||||||
|
ans[start:end] += this_chunk
|
||||||
|
count[start:end] += 1
|
||||||
|
|
||||||
|
ans /= np.maximum(count, 1e-12)
|
||||||
|
|
||||||
|
return (ans + 0.5).astype(np.int8)
|
||||||
|
|
||||||
|
|
||||||
|
def load_speaker_embedding_model(filename):
|
||||||
|
config = sherpa_onnx.SpeakerEmbeddingExtractorConfig(
|
||||||
|
model=filename,
|
||||||
|
num_threads=1,
|
||||||
|
debug=0,
|
||||||
|
)
|
||||||
|
if not config.validate():
|
||||||
|
raise ValueError(f"Invalid config. {config}")
|
||||||
|
extractor = sherpa_onnx.SpeakerEmbeddingExtractor(config)
|
||||||
|
return extractor
|
||||||
|
|
||||||
|
|
||||||
|
def get_embeddings(embedding_filename, audio, labels, seg_m, exclude_overlap):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
embedding_filename: Path to the speaker embedding extractor model
|
||||||
|
audio: (num_samples,)
|
||||||
|
labels: (num_chunks, num_frames, num_speakers)
|
||||||
|
seg_m: segmentation model
|
||||||
|
Returns:
|
||||||
|
Return (num_chunks, num_speakers, embedding_dim)
|
||||||
|
"""
|
||||||
|
if exclude_overlap:
|
||||||
|
labels = labels * (labels.sum(axis=-1, keepdims=True) < 2)
|
||||||
|
|
||||||
|
extractor = load_speaker_embedding_model(embedding_filename)
|
||||||
|
buffer = np.empty(seg_m.window_size)
|
||||||
|
num_chunks, num_frames, num_speakers = labels.shape
|
||||||
|
|
||||||
|
ans_chunk_speaker_pair = []
|
||||||
|
ans_embeddings = []
|
||||||
|
|
||||||
|
for i in range(num_chunks):
|
||||||
|
labels_T = labels[i].T
|
||||||
|
# t: (num_speakers, num_frames)
|
||||||
|
|
||||||
|
sample_offset = i * seg_m.window_shift
|
||||||
|
|
||||||
|
for j in range(num_speakers):
|
||||||
|
frames = labels_T[j]
|
||||||
|
if frames.sum() < 10:
|
||||||
|
# skip segment less than 20 frames, i.e., about 0.2 seconds
|
||||||
|
continue
|
||||||
|
|
||||||
|
start = None
|
||||||
|
start_samples = 0
|
||||||
|
idx = 0
|
||||||
|
for k in range(num_frames):
|
||||||
|
if frames[k] != 0:
|
||||||
|
if start is None:
|
||||||
|
start = k
|
||||||
|
elif start is not None:
|
||||||
|
start_samples = (
|
||||||
|
int(start / num_frames * seg_m.window_size) + sample_offset
|
||||||
|
)
|
||||||
|
end_samples = (
|
||||||
|
int(k / num_frames * seg_m.window_size) + sample_offset
|
||||||
|
)
|
||||||
|
num_samples = end_samples - start_samples
|
||||||
|
buffer[idx : idx + num_samples] = audio[start_samples:end_samples]
|
||||||
|
idx += num_samples
|
||||||
|
|
||||||
|
start = None
|
||||||
|
if start is not None:
|
||||||
|
start_samples = (
|
||||||
|
int(start / num_frames * seg_m.window_size) + sample_offset
|
||||||
|
)
|
||||||
|
end_samples = int(k / num_frames * seg_m.window_size) + sample_offset
|
||||||
|
num_samples = end_samples - start_samples
|
||||||
|
buffer[idx : idx + num_samples] = audio[start_samples:end_samples]
|
||||||
|
idx += num_samples
|
||||||
|
|
||||||
|
stream = extractor.create_stream()
|
||||||
|
stream.accept_waveform(sample_rate=seg_m.sample_rate, waveform=buffer[:idx])
|
||||||
|
stream.input_finished()
|
||||||
|
|
||||||
|
assert extractor.is_ready(stream)
|
||||||
|
embedding = extractor.compute(stream)
|
||||||
|
embedding = np.array(embedding)
|
||||||
|
|
||||||
|
ans_chunk_speaker_pair.append([i, j])
|
||||||
|
ans_embeddings.append(embedding)
|
||||||
|
|
||||||
|
assert len(ans_chunk_speaker_pair) == len(ans_embeddings), (
|
||||||
|
len(ans_chunk_speaker_pair),
|
||||||
|
len(ans_embeddings),
|
||||||
|
)
|
||||||
|
return ans_chunk_speaker_pair, np.array(ans_embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
assert Path(args.seg_model).is_file(), args.seg_model
|
||||||
|
assert Path(args.wav).is_file(), args.wav
|
||||||
|
|
||||||
|
seg_m = OnnxSegmentationModel(args.seg_model)
|
||||||
|
audio = load_wav(args.wav, seg_m.sample_rate)
|
||||||
|
# audio: (num_samples,)
|
||||||
|
|
||||||
|
num = (audio.shape[0] - seg_m.window_size) // seg_m.window_shift + 1
|
||||||
|
|
||||||
|
samples = as_strided(
|
||||||
|
audio,
|
||||||
|
shape=(num, seg_m.window_size),
|
||||||
|
strides=(seg_m.window_shift * audio.strides[0], audio.strides[0]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# or use torch.Tensor.unfold
|
||||||
|
# samples = torch.from_numpy(audio).unfold(0, seg_m.window_size, seg_m.window_shift).numpy()
|
||||||
|
|
||||||
|
if (
|
||||||
|
audio.shape[0] < seg_m.window_size
|
||||||
|
or (audio.shape[0] - seg_m.window_size) % seg_m.window_shift > 0
|
||||||
|
):
|
||||||
|
has_last_chunk = True
|
||||||
|
else:
|
||||||
|
has_last_chunk = False
|
||||||
|
|
||||||
|
num_chunks = samples.shape[0]
|
||||||
|
batch_size = 32
|
||||||
|
output = []
|
||||||
|
for i in range(0, num_chunks, batch_size):
|
||||||
|
start = i
|
||||||
|
end = i + batch_size
|
||||||
|
# it's perfectly ok to use end > num_chunks
|
||||||
|
y = seg_m(samples[start:end])
|
||||||
|
output.append(y)
|
||||||
|
|
||||||
|
if has_last_chunk:
|
||||||
|
last_chunk = audio[num_chunks * seg_m.window_shift :] # noqa
|
||||||
|
pad_size = seg_m.window_size - last_chunk.shape[0]
|
||||||
|
last_chunk = np.pad(last_chunk, (0, pad_size))
|
||||||
|
last_chunk = np.expand_dims(last_chunk, axis=0)
|
||||||
|
y = seg_m(last_chunk)
|
||||||
|
output.append(y)
|
||||||
|
|
||||||
|
y = np.vstack(output)
|
||||||
|
# y: (num_chunks, num_frames, num_classes)
|
||||||
|
|
||||||
|
mapping = get_powerset_mapping(
|
||||||
|
num_classes=seg_m.num_classes,
|
||||||
|
num_speakers=seg_m.num_speakers,
|
||||||
|
powerset_max_classes=seg_m.powerset_max_classes,
|
||||||
|
)
|
||||||
|
labels = to_multi_label(y, mapping=mapping)
|
||||||
|
# labels: (num_chunks, num_frames, num_speakers)
|
||||||
|
|
||||||
|
inactive = (labels.sum(axis=1) == 0).astype(np.int8)
|
||||||
|
# inactive: (num_chunks, num_speakers)
|
||||||
|
|
||||||
|
speakers_per_frame = speaker_count(labels=labels, seg_m=seg_m)
|
||||||
|
# speakers_per_frame: (num_frames, speakers_per_frame)
|
||||||
|
|
||||||
|
if speakers_per_frame.max() == 0:
|
||||||
|
print("No speakers found in the audio file!")
|
||||||
|
return
|
||||||
|
|
||||||
|
# if users specify only 1 speaker for clustering, then return the
|
||||||
|
# result directly
|
||||||
|
|
||||||
|
# Now, get embeddings
|
||||||
|
chunk_speaker_pair, embeddings = get_embeddings(
|
||||||
|
args.speaker_embedding_model,
|
||||||
|
audio=audio,
|
||||||
|
labels=labels,
|
||||||
|
seg_m=seg_m,
|
||||||
|
# exclude_overlap=True,
|
||||||
|
exclude_overlap=False,
|
||||||
|
)
|
||||||
|
# chunk_speaker_pair: a list of (chunk_idx, speaker_idx)
|
||||||
|
# embeddings: (batch_size, embedding_dim)
|
||||||
|
|
||||||
|
# Please change num_clusters or threshold by yourself.
|
||||||
|
clustering_config = sherpa_onnx.FastClusteringConfig(num_clusters=2)
|
||||||
|
# clustering_config = sherpa_onnx.FastClusteringConfig(threshold=0.8)
|
||||||
|
clustering = sherpa_onnx.FastClustering(clustering_config)
|
||||||
|
cluster_labels = clustering(embeddings)
|
||||||
|
|
||||||
|
chunk_speaker_to_cluster = dict()
|
||||||
|
for (chunk_idx, speaker_idx), cluster_idx in zip(
|
||||||
|
chunk_speaker_pair, cluster_labels
|
||||||
|
):
|
||||||
|
if inactive[chunk_idx, speaker_idx] == 1:
|
||||||
|
print("skip ", chunk_idx, speaker_idx)
|
||||||
|
continue
|
||||||
|
chunk_speaker_to_cluster[(chunk_idx, speaker_idx)] = cluster_idx
|
||||||
|
|
||||||
|
num_speakers = max(cluster_labels) + 1
|
||||||
|
relabels = np.zeros((labels.shape[0], labels.shape[1], num_speakers))
|
||||||
|
for i in range(labels.shape[0]):
|
||||||
|
for j in range(labels.shape[1]):
|
||||||
|
for k in range(labels.shape[2]):
|
||||||
|
if (i, k) not in chunk_speaker_to_cluster:
|
||||||
|
continue
|
||||||
|
t = chunk_speaker_to_cluster[(i, k)]
|
||||||
|
|
||||||
|
if labels[i, j, k] == 1:
|
||||||
|
relabels[i, j, t] = 1
|
||||||
|
|
||||||
|
num_frames = (
|
||||||
|
int(
|
||||||
|
(seg_m.window_size + (relabels.shape[0] - 1) * seg_m.window_shift)
|
||||||
|
/ seg_m.receptive_field_shift
|
||||||
|
)
|
||||||
|
+ 1
|
||||||
|
)
|
||||||
|
|
||||||
|
count = np.zeros((num_frames, relabels.shape[-1]))
|
||||||
|
for i in range(relabels.shape[0]):
|
||||||
|
this_chunk = relabels[i]
|
||||||
|
start = int(i * seg_m.window_shift / seg_m.receptive_field_shift + 0.5)
|
||||||
|
end = start + this_chunk.shape[0]
|
||||||
|
count[start:end] += this_chunk
|
||||||
|
|
||||||
|
if has_last_chunk:
|
||||||
|
stop_frame = int(audio.shape[0] / seg_m.receptive_field_shift)
|
||||||
|
count = count[:stop_frame]
|
||||||
|
|
||||||
|
sorted_count = np.argsort(-count, axis=-1)
|
||||||
|
final = np.zeros((count.shape[0], count.shape[1]))
|
||||||
|
|
||||||
|
for i, (c, sc) in enumerate(zip(speakers_per_frame, sorted_count)):
|
||||||
|
for k in range(c):
|
||||||
|
final[i, sc[k]] = 1
|
||||||
|
|
||||||
|
min_duration_off = 0.5
|
||||||
|
min_duration_on = 0.3
|
||||||
|
onset = 0.5
|
||||||
|
offset = 0.5
|
||||||
|
# final: (num_frames, num_speakers)
|
||||||
|
|
||||||
|
final = final.T
|
||||||
|
for kk in range(final.shape[0]):
|
||||||
|
segment_list = []
|
||||||
|
frames = final[kk]
|
||||||
|
|
||||||
|
is_active = frames[0] > onset
|
||||||
|
|
||||||
|
start = None
|
||||||
|
if is_active:
|
||||||
|
start = 0
|
||||||
|
scale = seg_m.receptive_field_shift / seg_m.sample_rate
|
||||||
|
scale_offset = seg_m.receptive_field_size / seg_m.sample_rate * 0.5
|
||||||
|
for i in range(1, len(frames)):
|
||||||
|
if is_active:
|
||||||
|
if frames[i] < offset:
|
||||||
|
segment = Segment(
|
||||||
|
start=start * scale + scale_offset,
|
||||||
|
end=i * scale + scale_offset,
|
||||||
|
speaker=kk,
|
||||||
|
)
|
||||||
|
segment_list.append(segment)
|
||||||
|
is_active = False
|
||||||
|
else:
|
||||||
|
if frames[i] > onset:
|
||||||
|
start = i
|
||||||
|
is_active = True
|
||||||
|
|
||||||
|
if is_active:
|
||||||
|
segment = Segment(
|
||||||
|
start=start * scale + scale_offset,
|
||||||
|
end=(len(frames) - 1) * scale + scale_offset,
|
||||||
|
speaker=kk,
|
||||||
|
)
|
||||||
|
segment_list.append(segment)
|
||||||
|
|
||||||
|
if len(segment_list) > 1:
|
||||||
|
merge_segment_list(segment_list, min_duration_off=min_duration_off)
|
||||||
|
for s in segment_list:
|
||||||
|
if s.duration < min_duration_on:
|
||||||
|
continue
|
||||||
|
print(s)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
86
scripts/pyannote/segmentation/speaker-diarization-torch.py
Executable file
86
scripts/pyannote/segmentation/speaker-diarization-torch.py
Executable file
@@ -0,0 +1,86 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
"""
|
||||||
|
Please refer to
|
||||||
|
https://github.com/k2-fsa/sherpa-onnx/blob/master/.github/workflows/speaker-diarization.yaml
|
||||||
|
for usages.
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
1. Go to https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/tree/main
|
||||||
|
wget https://huggingface.co/hbredin/wespeaker-voxceleb-resnet34-LM/resolve/main/speaker-embedding.onnx
|
||||||
|
|
||||||
|
2. Change line 166 of pyannote/audio/pipelines/speaker_diarization.py
|
||||||
|
|
||||||
|
```
|
||||||
|
# self._embedding = PretrainedSpeakerEmbedding(
|
||||||
|
# self.embedding, use_auth_token=use_auth_token
|
||||||
|
# )
|
||||||
|
self._embedding = embedding
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pyannote.audio import Model
|
||||||
|
from pyannote.audio.pipelines import SpeakerDiarization as SpeakerDiarizationPipeline
|
||||||
|
from pyannote.audio.pipelines.speaker_verification import (
|
||||||
|
ONNXWeSpeakerPretrainedSpeakerEmbedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--wav", type=str, required=True, help="Path to test.wav")
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def build_pipeline():
|
||||||
|
embedding_filename = "./speaker-embedding.onnx"
|
||||||
|
if Path(embedding_filename).is_file():
|
||||||
|
# You need to modify line 166
|
||||||
|
# of pyannote/audio/pipelines/speaker_diarization.py
|
||||||
|
# Please see the comments at the start of this script for details
|
||||||
|
embedding = ONNXWeSpeakerPretrainedSpeakerEmbedding(embedding_filename)
|
||||||
|
else:
|
||||||
|
embedding = "hbredin/wespeaker-voxceleb-resnet34-LM"
|
||||||
|
|
||||||
|
pt_filename = "./pytorch_model.bin"
|
||||||
|
segmentation = Model.from_pretrained(pt_filename)
|
||||||
|
segmentation.eval()
|
||||||
|
|
||||||
|
pipeline = SpeakerDiarizationPipeline(
|
||||||
|
segmentation=segmentation,
|
||||||
|
embedding=embedding,
|
||||||
|
embedding_exclude_overlap=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"clustering": {
|
||||||
|
"method": "centroid",
|
||||||
|
"min_cluster_size": 12,
|
||||||
|
"threshold": 0.7045654963945799,
|
||||||
|
},
|
||||||
|
"segmentation": {"min_duration_off": 0.5},
|
||||||
|
}
|
||||||
|
|
||||||
|
pipeline.instantiate(params)
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
assert Path(args.wav).is_file(), args.wav
|
||||||
|
pipeline = build_pipeline()
|
||||||
|
print(pipeline)
|
||||||
|
t = pipeline(args.wav)
|
||||||
|
print(type(t))
|
||||||
|
print(t)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -52,7 +52,7 @@ class FastClustering::Impl {
|
|||||||
std::vector<double> height(num_rows - 1);
|
std::vector<double> height(num_rows - 1);
|
||||||
|
|
||||||
fastclustercpp::hclust_fast(num_rows, distance.data(),
|
fastclustercpp::hclust_fast(num_rows, distance.data(),
|
||||||
fastclustercpp::HCLUST_METHOD_SINGLE,
|
fastclustercpp::HCLUST_METHOD_COMPLETE,
|
||||||
merge.data(), height.data());
|
merge.data(), height.data());
|
||||||
|
|
||||||
std::vector<int32_t> labels(num_rows);
|
std::vector<int32_t> labels(num_rows);
|
||||||
|
|||||||
Reference in New Issue
Block a user