C++ API for speaker diarization (#1396)
This commit is contained in:
133
sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
Normal file
133
sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
Normal file
@@ -0,0 +1,133 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>

#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
// Progress callback handed to OfflineSpeakerDiarization::Process() in main().
//
// Prints the percentage of processed chunks to stderr.
//
// @param processed_chunks Number of chunks processed so far.
// @param num_chunks       Total number of chunks.
// @param arg              Opaque user pointer; not used by this callback.
// @return Always 0.
static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks,
                                void *arg) {
  static_cast<void>(arg);  // unused

  // Guard against a non-positive chunk count so that we never print
  // "nan%" or "inf%".
  const float progress =
      num_chunks > 0
          ? static_cast<float>(100.0 * processed_chunks / num_chunks)
          : 0.0f;
  fprintf(stderr, "progress %.2f%%\n", progress);

  // the return value is currently ignored
  return 0;
}
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Offline/Non-streaming speaker diarization with sherpa-onnx
|
||||
Usage example:
|
||||
|
||||
Step 1: Download a speaker segmentation model
|
||||
|
||||
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
|
||||
for a list of available models. The following is an example
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
|
||||
Step 2: Download a speaker embedding extractor model
|
||||
|
||||
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
|
||||
for a list of available models. The following is an example
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
|
||||
|
||||
Step 3. Download test wave files
|
||||
|
||||
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
|
||||
for a list of available test wave files. The following is an example
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
|
||||
|
||||
Step 4. Build sherpa-onnx
|
||||
|
||||
Step 5. Run it
|
||||
|
||||
./bin/sherpa-onnx-offline-speaker-diarization \
|
||||
--clustering.num-clusters=4 \
|
||||
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
|
||||
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
|
||||
./0-four-speakers-zh.wav
|
||||
|
||||
Since we know that there are four speakers in the test wave file, we use
|
||||
--clustering.num-clusters=4 in the above example.
|
||||
|
||||
If we don't know number of speakers in the given wave file, we can use
|
||||
the argument --clustering.cluster-threshold. The following is an example:
|
||||
|
||||
./bin/sherpa-onnx-offline-speaker-diarization \
|
||||
--clustering.cluster-threshold=0.90 \
|
||||
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
|
||||
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
|
||||
./0-four-speakers-zh.wav
|
||||
|
||||
A larger threshold leads to few clusters, i.e., few speakers;
|
||||
a smaller threshold leads to more clusters, i.e., more speakers
|
||||
)usage";
|
||||
sherpa_onnx::OfflineSpeakerDiarizationConfig config;
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
config.Register(&po);
|
||||
po.Read(argc, argv);
|
||||
|
||||
std::cout << config.ToString() << "\n";
|
||||
|
||||
if (!config.Validate()) {
|
||||
po.PrintUsage();
|
||||
std::cerr << "Errors in config!\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (po.NumArgs() != 1) {
|
||||
std::cerr << "Error: Please provide exactly 1 wave file.\n\n";
|
||||
po.PrintUsage();
|
||||
return -1;
|
||||
}
|
||||
|
||||
sherpa_onnx::OfflineSpeakerDiarization sd(config);
|
||||
|
||||
std::cout << "Started\n";
|
||||
const auto begin = std::chrono::steady_clock::now();
|
||||
const std::string wav_filename = po.GetArg(1);
|
||||
int32_t sample_rate = -1;
|
||||
bool is_ok = false;
|
||||
const std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wav_filename, &sample_rate, &is_ok);
|
||||
if (!is_ok) {
|
||||
std::cerr << "Failed to read " << wav_filename.c_str() << "\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (sample_rate != sd.SampleRate()) {
|
||||
std::cerr << "Expect sample rate " << sd.SampleRate()
|
||||
<< ". Given: " << sample_rate << "\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
float duration = samples.size() / static_cast<float>(sample_rate);
|
||||
|
||||
auto result =
|
||||
sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr)
|
||||
.SortByStartTime();
|
||||
|
||||
for (const auto &r : result) {
|
||||
std::cout << r.ToString() << "\n";
|
||||
}
|
||||
|
||||
const auto end = std::chrono::steady_clock::now();
|
||||
float elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||
.count() /
|
||||
1000.;
|
||||
|
||||
fprintf(stderr, "Duration : %.3f s\n", duration);
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
float rtf = elapsed_seconds / duration;
|
||||
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
|
||||
elapsed_seconds, duration, rtf);
|
||||
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user