This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex-mr_series-sherpa-onnx/sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
2024-10-09 17:10:03 +08:00

134 lines
4.5 KiB
C++

// sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
// Progress hook handed to the diarization run.
//
// Writes the completed percentage (processed_chunks relative to num_chunks)
// to stderr with two decimal places. The third (user data) argument is
// unused.
//
// Always returns 0; the return value is currently ignored by the caller.
static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks,
                                void * /*unused*/) {
  // Evaluate in double precision first and only then narrow to float, so the
  // printed value is the same as with a single combined expression.
  const double scaled = 100.0 * processed_chunks;
  const float percent = static_cast<float>(scaled / num_chunks);

  fprintf(stderr, "progress %.2f%%\n", percent);

  return 0;
}
// Command-line entry point for offline (non-streaming) speaker diarization.
//
// Registers the diarization options with the option parser, validates the
// resulting config, reads the single wave file given as positional argument,
// runs diarization on it and prints every resulting segment (sorted by start
// time) to stdout. Progress, audio duration, elapsed time and the real time
// factor (RTF) are written to stderr.
//
// Returns 0 on success. Returns -1 if the config is invalid, if the number
// of positional arguments is not exactly 1, if the wave file cannot be read,
// if it contains no samples, or if its sample rate differs from the one
// expected by the diarization object.
int main(int32_t argc, char *argv[]) {
  const char *kUsageMessage = R"usage(
Offline/Non-streaming speaker diarization with sherpa-onnx
Usage example:
Step 1: Download a speaker segmentation model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
Step 2: Download a speaker embedding extractor model
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
for a list of available models. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
Step 3. Download test wave files
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
for a list of available test wave files. The following is an example
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
Step 4. Build sherpa-onnx
Step 5. Run it
./bin/sherpa-onnx-offline-speaker-diarization \
--clustering.num-clusters=4 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav
Since we know that there are four speakers in the test wave file, we use
--clustering.num-clusters=4 in the above example.
If we don't know number of speakers in the given wave file, we can use
the argument --clustering.cluster-threshold. The following is an example:
./bin/sherpa-onnx-offline-speaker-diarization \
--clustering.cluster-threshold=0.90 \
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
./0-four-speakers-zh.wav
A larger threshold leads to few clusters, i.e., few speakers;
a smaller threshold leads to more clusters, i.e., more speakers
)usage";

  sherpa_onnx::OfflineSpeakerDiarizationConfig config;
  sherpa_onnx::ParseOptions po(kUsageMessage);
  config.Register(&po);
  po.Read(argc, argv);

  // Echo the effective configuration before validating it, so that a
  // rejected config is still visible to the user.
  std::cout << config.ToString() << "\n";

  if (!config.Validate()) {
    po.PrintUsage();
    std::cerr << "Errors in config!\n";
    return -1;
  }

  if (po.NumArgs() != 1) {
    std::cerr << "Error: Please provide exactly 1 wave file.\n\n";
    po.PrintUsage();
    return -1;
  }

  sherpa_onnx::OfflineSpeakerDiarization sd(config);

  std::cout << "Started\n";

  // The timer starts before the wave file is read and stops after the
  // segments are printed, so the elapsed time covers both of those steps in
  // addition to the diarization itself.
  const auto begin = std::chrono::steady_clock::now();

  const std::string wav_filename = po.GetArg(1);
  int32_t sample_rate = -1;
  bool is_ok = false;
  const std::vector<float> samples =
      sherpa_onnx::ReadWave(wav_filename, &sample_rate, &is_ok);
  if (!is_ok) {
    std::cerr << "Failed to read " << wav_filename << "\n";
    return -1;
  }

  // A readable file without any samples has a duration of 0 seconds: there
  // is nothing to diarize and the RTF below would be a division by zero.
  if (samples.empty()) {
    std::cerr << "No samples found in " << wav_filename << "\n";
    return -1;
  }

  if (sample_rate != sd.SampleRate()) {
    std::cerr << "Expect sample rate " << sd.SampleRate()
              << ". Given: " << sample_rate << "\n";
    return -1;
  }

  // Audio length in seconds.
  float duration = samples.size() / static_cast<float>(sample_rate);

  auto result =
      sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr)
          .SortByStartTime();

  for (const auto &r : result) {
    std::cout << r.ToString() << "\n";
  }

  const auto end = std::chrono::steady_clock::now();

  // Millisecond resolution, reported in seconds.
  float elapsed_seconds =
      std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
          .count() /
      1000.;

  fprintf(stderr, "Duration : %.3f s\n", duration);
  fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);

  // RTF = processing time / audio duration.
  float rtf = elapsed_seconds / duration;
  fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
          elapsed_seconds, duration, rtf);

  return 0;
}