134 lines
4.5 KiB
C++
134 lines
4.5 KiB
C++
// sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
|
|
//
|
|
// Copyright (c) 2024 Xiaomi Corporation
|
|
|
|
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
|
|
#include "sherpa-onnx/csrc/parse-options.h"
|
|
#include "sherpa-onnx/csrc/wave-reader.h"
|
|
|
|
static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks,
|
|
void *arg) {
|
|
float progress = 100.0 * processed_chunks / num_chunks;
|
|
fprintf(stderr, "progress %.2f%%\n", progress);
|
|
|
|
// the return value is currently ignored
|
|
return 0;
|
|
}
|
|
|
|
int main(int32_t argc, char *argv[]) {
|
|
const char *kUsageMessage = R"usage(
|
|
Offline/Non-streaming speaker diarization with sherpa-onnx
|
|
Usage example:
|
|
|
|
Step 1: Download a speaker segmentation model
|
|
|
|
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
|
|
for a list of available models. The following is an example
|
|
|
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
|
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
|
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
|
|
|
Step 2: Download a speaker embedding extractor model
|
|
|
|
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
|
|
for a list of available models. The following is an example
|
|
|
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
|
|
|
|
Step 3. Download test wave files
|
|
|
|
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
|
|
for a list of available test wave files. The following is an example
|
|
|
|
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
|
|
|
|
Step 4. Build sherpa-onnx
|
|
|
|
Step 5. Run it
|
|
|
|
./bin/sherpa-onnx-offline-speaker-diarization \
|
|
--clustering.num-clusters=4 \
|
|
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
|
|
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
|
|
./0-four-speakers-zh.wav
|
|
|
|
Since we know that there are four speakers in the test wave file, we use
|
|
--clustering.num-clusters=4 in the above example.
|
|
|
|
If we don't know number of speakers in the given wave file, we can use
|
|
the argument --clustering.cluster-threshold. The following is an example:
|
|
|
|
./bin/sherpa-onnx-offline-speaker-diarization \
|
|
--clustering.cluster-threshold=0.90 \
|
|
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
|
|
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
|
|
./0-four-speakers-zh.wav
|
|
|
|
A larger threshold leads to few clusters, i.e., few speakers;
|
|
a smaller threshold leads to more clusters, i.e., more speakers
|
|
)usage";
|
|
sherpa_onnx::OfflineSpeakerDiarizationConfig config;
|
|
sherpa_onnx::ParseOptions po(kUsageMessage);
|
|
config.Register(&po);
|
|
po.Read(argc, argv);
|
|
|
|
std::cout << config.ToString() << "\n";
|
|
|
|
if (!config.Validate()) {
|
|
po.PrintUsage();
|
|
std::cerr << "Errors in config!\n";
|
|
return -1;
|
|
}
|
|
|
|
if (po.NumArgs() != 1) {
|
|
std::cerr << "Error: Please provide exactly 1 wave file.\n\n";
|
|
po.PrintUsage();
|
|
return -1;
|
|
}
|
|
|
|
sherpa_onnx::OfflineSpeakerDiarization sd(config);
|
|
|
|
std::cout << "Started\n";
|
|
const auto begin = std::chrono::steady_clock::now();
|
|
const std::string wav_filename = po.GetArg(1);
|
|
int32_t sample_rate = -1;
|
|
bool is_ok = false;
|
|
const std::vector<float> samples =
|
|
sherpa_onnx::ReadWave(wav_filename, &sample_rate, &is_ok);
|
|
if (!is_ok) {
|
|
std::cerr << "Failed to read " << wav_filename.c_str() << "\n";
|
|
return -1;
|
|
}
|
|
|
|
if (sample_rate != sd.SampleRate()) {
|
|
std::cerr << "Expect sample rate " << sd.SampleRate()
|
|
<< ". Given: " << sample_rate << "\n";
|
|
return -1;
|
|
}
|
|
|
|
float duration = samples.size() / static_cast<float>(sample_rate);
|
|
|
|
auto result =
|
|
sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr)
|
|
.SortByStartTime();
|
|
|
|
for (const auto &r : result) {
|
|
std::cout << r.ToString() << "\n";
|
|
}
|
|
|
|
const auto end = std::chrono::steady_clock::now();
|
|
float elapsed_seconds =
|
|
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
|
.count() /
|
|
1000.;
|
|
|
|
fprintf(stderr, "Duration : %.3f s\n", duration);
|
|
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
|
float rtf = elapsed_seconds / duration;
|
|
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
|
|
elapsed_seconds, duration, rtf);
|
|
|
|
return 0;
|
|
}
|