C++ API for speaker diarization (#1396)
This commit is contained in:
@@ -164,6 +164,12 @@ if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
|
||||
list(APPEND sources
|
||||
fast-clustering-config.cc
|
||||
fast-clustering.cc
|
||||
offline-speaker-diarization-impl.cc
|
||||
offline-speaker-diarization-result.cc
|
||||
offline-speaker-diarization.cc
|
||||
offline-speaker-segmentation-model-config.cc
|
||||
offline-speaker-segmentation-pyannote-model-config.cc
|
||||
offline-speaker-segmentation-pyannote-model.cc
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -260,6 +266,10 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
||||
endif()
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
|
||||
add_executable(sherpa-onnx-offline-speaker-diarization sherpa-onnx-offline-speaker-diarization.cc)
|
||||
endif()
|
||||
|
||||
set(main_exes
|
||||
sherpa-onnx
|
||||
sherpa-onnx-keyword-spotter
|
||||
@@ -276,6 +286,12 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
)
|
||||
endif()
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_SPEAKER_DIARIZATION)
|
||||
list(APPEND main_exes
|
||||
sherpa-onnx-offline-speaker-diarization
|
||||
)
|
||||
endif()
|
||||
|
||||
foreach(exe IN LISTS main_exes)
|
||||
target_link_libraries(${exe} sherpa-onnx-core)
|
||||
endforeach()
|
||||
|
||||
@@ -21,18 +21,16 @@ std::string FastClusteringConfig::ToString() const {
|
||||
}
|
||||
|
||||
void FastClusteringConfig::Register(ParseOptions *po) {
|
||||
std::string prefix = "ctc";
|
||||
ParseOptions p(prefix, po);
|
||||
po->Register(
|
||||
"num-clusters", &num_clusters,
|
||||
"Number of cluster. If greater than 0, then cluster threshold is "
|
||||
"ignored. Please provide it if you know the actual number of "
|
||||
"clusters in advance.");
|
||||
|
||||
p.Register("num-clusters", &num_clusters,
|
||||
"Number of cluster. If greater than 0, then --cluster-thresold is "
|
||||
"ignored. Please provide it if you know the actual number of "
|
||||
"clusters in advance.");
|
||||
|
||||
p.Register("cluster-threshold", &threshold,
|
||||
"If --num-clusters is not specified, then it specifies the "
|
||||
"distance threshold for clustering. smaller value -> more "
|
||||
"clusters. larger value -> fewer clusters");
|
||||
po->Register("cluster-threshold", &threshold,
|
||||
"If num_clusters is not specified, then it specifies the "
|
||||
"distance threshold for clustering. smaller value -> more "
|
||||
"clusters. larger value -> fewer clusters");
|
||||
}
|
||||
|
||||
bool FastClusteringConfig::Validate() const {
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#ifndef SHERPA_ONNX_CSRC_MACROS_H_
|
||||
#define SHERPA_ONNX_CSRC_MACROS_H_
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#if __ANDROID_API__ >= 8
|
||||
#include "android/log.h"
|
||||
@@ -169,4 +170,6 @@
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define SHERPA_ONNX_EXIT(code) exit(code)
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_MACROS_H_
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
|
||||
26
sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
Normal file
26
sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
Normal file
@@ -0,0 +1,26 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-diarization-impl.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<OfflineSpeakerDiarizationImpl>
|
||||
OfflineSpeakerDiarizationImpl::Create(
|
||||
const OfflineSpeakerDiarizationConfig &config) {
|
||||
if (!config.segmentation.pyannote.model.empty()) {
|
||||
return std::make_unique<OfflineSpeakerDiarizationPyannoteImpl>(config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("Please specify a speaker segmentation model.");
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
31
sherpa-onnx/csrc/offline-speaker-diarization-impl.h
Normal file
31
sherpa-onnx/csrc/offline-speaker-diarization-impl.h
Normal file
@@ -0,0 +1,31 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-diarization-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineSpeakerDiarizationImpl {
|
||||
public:
|
||||
static std::unique_ptr<OfflineSpeakerDiarizationImpl> Create(
|
||||
const OfflineSpeakerDiarizationConfig &config);
|
||||
|
||||
virtual ~OfflineSpeakerDiarizationImpl() = default;
|
||||
|
||||
virtual int32_t SampleRate() const = 0;
|
||||
|
||||
virtual OfflineSpeakerDiarizationResult Process(
|
||||
const float *audio, int32_t n,
|
||||
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
|
||||
void *callback_arg = nullptr) const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_IMPL_H_
|
||||
644
sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
Normal file
644
sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
Normal file
@@ -0,0 +1,644 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-diarization-pyannote-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Dense"
|
||||
#include "sherpa-onnx/csrc/fast-clustering.h"
|
||||
#include "sherpa-onnx/csrc/math.h"
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
namespace { // NOLINT
|
||||
|
||||
// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L41
|
||||
template <class T>
|
||||
inline void hash_combine(std::size_t *seed, const T &v) { // NOLINT
|
||||
std::hash<T> hasher;
|
||||
*seed ^= hasher(v) + 0x9e3779b9 + ((*seed) << 6) + ((*seed) >> 2); // NOLINT
|
||||
}
|
||||
|
||||
// copied from https://github.com/k2-fsa/k2/blob/master/k2/csrc/host/util.h#L47
|
||||
struct PairHash {
|
||||
template <class T1, class T2>
|
||||
std::size_t operator()(const std::pair<T1, T2> &pair) const {
|
||||
std::size_t result = 0;
|
||||
hash_combine(&result, pair.first);
|
||||
hash_combine(&result, pair.second);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
using Matrix2D =
|
||||
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
|
||||
using Matrix2DInt32 =
|
||||
Eigen::Matrix<int32_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
|
||||
|
||||
using FloatRowVector = Eigen::Matrix<float, 1, Eigen::Dynamic>;
|
||||
using Int32RowVector = Eigen::Matrix<int32_t, 1, Eigen::Dynamic>;
|
||||
|
||||
using Int32Pair = std::pair<int32_t, int32_t>;
|
||||
|
||||
class OfflineSpeakerDiarizationPyannoteImpl
|
||||
: public OfflineSpeakerDiarizationImpl {
|
||||
public:
|
||||
~OfflineSpeakerDiarizationPyannoteImpl() override = default;
|
||||
|
||||
explicit OfflineSpeakerDiarizationPyannoteImpl(
|
||||
const OfflineSpeakerDiarizationConfig &config)
|
||||
: config_(config),
|
||||
segmentation_model_(config_.segmentation),
|
||||
embedding_extractor_(config_.embedding),
|
||||
clustering_(config_.clustering) {
|
||||
Init();
|
||||
}
|
||||
|
||||
int32_t SampleRate() const override {
|
||||
const auto &meta_data = segmentation_model_.GetModelMetaData();
|
||||
|
||||
return meta_data.sample_rate;
|
||||
}
|
||||
|
||||
OfflineSpeakerDiarizationResult Process(
|
||||
const float *audio, int32_t n,
|
||||
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
|
||||
void *callback_arg = nullptr) const override {
|
||||
std::vector<Matrix2D> segmentations = RunSpeakerSegmentationModel(audio, n);
|
||||
// segmentations[i] is for chunk_i
|
||||
// Each matrix is of shape (num_frames, num_powerset_classes)
|
||||
if (segmentations.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<Matrix2DInt32> labels;
|
||||
labels.reserve(segmentations.size());
|
||||
|
||||
for (const auto &m : segmentations) {
|
||||
labels.push_back(ToMultiLabel(m));
|
||||
}
|
||||
|
||||
segmentations.clear();
|
||||
|
||||
// labels[i] is a 0-1 matrix of shape (num_frames, num_speakers)
|
||||
|
||||
// speaker count per frame
|
||||
Int32RowVector speakers_per_frame = ComputeSpeakersPerFrame(labels);
|
||||
|
||||
if (speakers_per_frame.maxCoeff() == 0) {
|
||||
SHERPA_ONNX_LOGE("No speakers found in the audio samples");
|
||||
return {};
|
||||
}
|
||||
|
||||
auto chunk_speaker_samples_list_pair = GetChunkSpeakerSampleIndexes(labels);
|
||||
Matrix2D embeddings =
|
||||
ComputeEmbeddings(audio, n, chunk_speaker_samples_list_pair.second,
|
||||
callback, callback_arg);
|
||||
|
||||
std::vector<int32_t> cluster_labels = clustering_.Cluster(
|
||||
&embeddings(0, 0), embeddings.rows(), embeddings.cols());
|
||||
|
||||
int32_t max_cluster_index =
|
||||
*std::max_element(cluster_labels.begin(), cluster_labels.end());
|
||||
|
||||
auto chunk_speaker_to_cluster = ConvertChunkSpeakerToCluster(
|
||||
chunk_speaker_samples_list_pair.first, cluster_labels);
|
||||
|
||||
auto new_labels =
|
||||
ReLabel(labels, max_cluster_index, chunk_speaker_to_cluster);
|
||||
|
||||
Matrix2DInt32 speaker_count = ComputeSpeakerCount(new_labels, n);
|
||||
|
||||
Matrix2DInt32 final_labels =
|
||||
FinalizeLabels(speaker_count, speakers_per_frame);
|
||||
|
||||
auto result = ComputeResult(final_labels);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
void Init() { InitPowersetMapping(); }
|
||||
|
||||
// see also
|
||||
// https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/utils/powerset.py#L68
|
||||
void InitPowersetMapping() {
|
||||
const auto &meta_data = segmentation_model_.GetModelMetaData();
|
||||
int32_t num_classes = meta_data.num_classes;
|
||||
int32_t powerset_max_classes = meta_data.powerset_max_classes;
|
||||
int32_t num_speakers = meta_data.num_speakers;
|
||||
|
||||
powerset_mapping_ = Matrix2DInt32(num_classes, num_speakers);
|
||||
powerset_mapping_.setZero();
|
||||
|
||||
int32_t k = 1;
|
||||
for (int32_t i = 1; i <= powerset_max_classes; ++i) {
|
||||
if (i == 1) {
|
||||
for (int32_t j = 0; j != num_speakers; ++j, ++k) {
|
||||
powerset_mapping_(k, j) = 1;
|
||||
}
|
||||
} else if (i == 2) {
|
||||
for (int32_t j = 0; j != num_speakers; ++j) {
|
||||
for (int32_t m = j + 1; m < num_speakers; ++m, ++k) {
|
||||
powerset_mapping_(k, j) = 1;
|
||||
powerset_mapping_(k, m) = 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"powerset_max_classes = %d is currently not supported!", i);
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<Matrix2D> RunSpeakerSegmentationModel(const float *audio,
|
||||
int32_t n) const {
|
||||
std::vector<Matrix2D> ans;
|
||||
|
||||
const auto &meta_data = segmentation_model_.GetModelMetaData();
|
||||
int32_t window_size = meta_data.window_size;
|
||||
int32_t window_shift = meta_data.window_shift;
|
||||
|
||||
if (n <= 0) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"number of audio samples is %d (<= 0). Please provide a positive "
|
||||
"number",
|
||||
n);
|
||||
return {};
|
||||
}
|
||||
|
||||
if (n <= window_size) {
|
||||
std::vector<float> buf(window_size);
|
||||
// NOTE: buf is zero initialized by default
|
||||
|
||||
std::copy(audio, audio + n, buf.data());
|
||||
|
||||
Matrix2D m = ProcessChunk(buf.data());
|
||||
|
||||
ans.push_back(std::move(m));
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
int32_t num_chunks = (n - window_size) / window_shift + 1;
|
||||
bool has_last_chunk = (n - window_size) % window_shift > 0;
|
||||
|
||||
ans.reserve(num_chunks + has_last_chunk);
|
||||
|
||||
const float *p = audio;
|
||||
|
||||
for (int32_t i = 0; i != num_chunks; ++i, p += window_shift) {
|
||||
Matrix2D m = ProcessChunk(p);
|
||||
|
||||
ans.push_back(std::move(m));
|
||||
}
|
||||
|
||||
if (has_last_chunk) {
|
||||
std::vector<float> buf(window_size);
|
||||
std::copy(p, audio + n, buf.data());
|
||||
|
||||
Matrix2D m = ProcessChunk(buf.data());
|
||||
|
||||
ans.push_back(std::move(m));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
Matrix2D ProcessChunk(const float *p) const {
|
||||
const auto &meta_data = segmentation_model_.GetModelMetaData();
|
||||
int32_t window_size = meta_data.window_size;
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 3> shape = {1, 1, window_size};
|
||||
|
||||
Ort::Value x =
|
||||
Ort::Value::CreateTensor(memory_info, const_cast<float *>(p),
|
||||
window_size, shape.data(), shape.size());
|
||||
|
||||
Ort::Value out = segmentation_model_.Forward(std::move(x));
|
||||
std::vector<int64_t> out_shape = out.GetTensorTypeAndShapeInfo().GetShape();
|
||||
Matrix2D m(out_shape[1], out_shape[2]);
|
||||
std::copy(out.GetTensorData<float>(), out.GetTensorData<float>() + m.size(),
|
||||
&m(0, 0));
|
||||
return m;
|
||||
}
|
||||
|
||||
Matrix2DInt32 ToMultiLabel(const Matrix2D &m) const {
|
||||
int32_t num_rows = m.rows();
|
||||
Matrix2DInt32 ans(num_rows, powerset_mapping_.cols());
|
||||
|
||||
std::ptrdiff_t col_id;
|
||||
|
||||
for (int32_t i = 0; i != num_rows; ++i) {
|
||||
m.row(i).maxCoeff(&col_id);
|
||||
ans.row(i) = powerset_mapping_.row(col_id);
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
// See also
|
||||
// https://github.com/pyannote/pyannote-audio/blob/develop/pyannote/audio/pipelines/utils/diarization.py#L122
|
||||
Int32RowVector ComputeSpeakersPerFrame(
|
||||
const std::vector<Matrix2DInt32> &labels) const {
|
||||
const auto &meta_data = segmentation_model_.GetModelMetaData();
|
||||
int32_t window_size = meta_data.window_size;
|
||||
int32_t window_shift = meta_data.window_shift;
|
||||
int32_t receptive_field_shift = meta_data.receptive_field_shift;
|
||||
|
||||
int32_t num_chunks = labels.size();
|
||||
|
||||
int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) /
|
||||
receptive_field_shift +
|
||||
1;
|
||||
|
||||
FloatRowVector count(num_frames);
|
||||
FloatRowVector weight(num_frames);
|
||||
count.setZero();
|
||||
weight.setZero();
|
||||
|
||||
for (int32_t i = 0; i != num_chunks; ++i) {
|
||||
int32_t start =
|
||||
static_cast<float>(i) * window_shift / receptive_field_shift + 0.5;
|
||||
|
||||
auto seq = Eigen::seqN(start, labels[i].rows());
|
||||
|
||||
count(seq).array() += labels[i].rowwise().sum().array().cast<float>();
|
||||
|
||||
weight(seq).array() += 1;
|
||||
}
|
||||
|
||||
return ((count.array() / (weight.array() + 1e-12f)) + 0.5).cast<int32_t>();
|
||||
}
|
||||
|
||||
// ans.first: a list of (chunk_id, speaker_id)
|
||||
// ans.second: a list of list of (start_sample_index, end_sample_index)
|
||||
//
|
||||
// ans.first[i] corresponds to ans.second[i]
|
||||
std::pair<std::vector<Int32Pair>, std::vector<std::vector<Int32Pair>>>
|
||||
GetChunkSpeakerSampleIndexes(const std::vector<Matrix2DInt32> &labels) const {
|
||||
auto new_labels = ExcludeOverlap(labels);
|
||||
|
||||
std::vector<Int32Pair> chunk_speaker_list;
|
||||
std::vector<std::vector<Int32Pair>> samples_index_list;
|
||||
|
||||
const auto &meta_data = segmentation_model_.GetModelMetaData();
|
||||
int32_t window_size = meta_data.window_size;
|
||||
int32_t window_shift = meta_data.window_shift;
|
||||
int32_t receptive_field_shift = meta_data.receptive_field_shift;
|
||||
int32_t num_speakers = meta_data.num_speakers;
|
||||
|
||||
int32_t chunk_index = 0;
|
||||
for (const auto &label : new_labels) {
|
||||
Matrix2DInt32 tmp = label.transpose();
|
||||
// tmp: (num_speakers, num_frames)
|
||||
int32_t num_frames = tmp.cols();
|
||||
|
||||
int32_t sample_offset = chunk_index * window_shift;
|
||||
|
||||
for (int32_t speaker_index = 0; speaker_index != num_speakers;
|
||||
++speaker_index) {
|
||||
auto d = tmp.row(speaker_index);
|
||||
if (d.sum() < 10) {
|
||||
// skip segments less than 10 frames
|
||||
continue;
|
||||
}
|
||||
|
||||
Int32Pair this_chunk_speaker = {chunk_index, speaker_index};
|
||||
std::vector<Int32Pair> this_speaker_samples;
|
||||
|
||||
bool is_active = false;
|
||||
int32_t start_index;
|
||||
|
||||
for (int32_t k = 0; k != num_frames; ++k) {
|
||||
if (d[k] != 0) {
|
||||
if (!is_active) {
|
||||
is_active = true;
|
||||
start_index = k;
|
||||
}
|
||||
} else if (is_active) {
|
||||
is_active = false;
|
||||
|
||||
int32_t start_samples =
|
||||
static_cast<float>(start_index) / num_frames * window_size +
|
||||
sample_offset;
|
||||
int32_t end_samples =
|
||||
static_cast<float>(k) / num_frames * window_size +
|
||||
sample_offset;
|
||||
|
||||
this_speaker_samples.emplace_back(start_samples, end_samples);
|
||||
}
|
||||
}
|
||||
|
||||
if (is_active) {
|
||||
int32_t start_samples =
|
||||
static_cast<float>(start_index) / num_frames * window_size +
|
||||
sample_offset;
|
||||
int32_t end_samples =
|
||||
static_cast<float>(num_frames - 1) / num_frames * window_size +
|
||||
sample_offset;
|
||||
this_speaker_samples.emplace_back(start_samples, end_samples);
|
||||
}
|
||||
|
||||
chunk_speaker_list.push_back(std::move(this_chunk_speaker));
|
||||
samples_index_list.push_back(std::move(this_speaker_samples));
|
||||
} // for (int32_t speaker_index = 0;
|
||||
chunk_index += 1;
|
||||
} // for (const auto &label : new_labels)
|
||||
|
||||
return {chunk_speaker_list, samples_index_list};
|
||||
}
|
||||
|
||||
// If there are multiple speakers at a frame, then this frame is excluded.
|
||||
std::vector<Matrix2DInt32> ExcludeOverlap(
|
||||
const std::vector<Matrix2DInt32> &labels) const {
|
||||
int32_t num_chunks = labels.size();
|
||||
std::vector<Matrix2DInt32> ans;
|
||||
ans.reserve(num_chunks);
|
||||
|
||||
for (const auto &label : labels) {
|
||||
Matrix2DInt32 new_label(label.rows(), label.cols());
|
||||
new_label.setZero();
|
||||
Int32RowVector v = label.rowwise().sum();
|
||||
|
||||
for (int32_t i = 0; i != v.cols(); ++i) {
|
||||
if (v[i] < 2) {
|
||||
new_label.row(i) = label.row(i);
|
||||
}
|
||||
}
|
||||
|
||||
ans.push_back(std::move(new_label));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param sample_indexes[i] contains the sample segment start and end indexes
|
||||
* for the i-th (chunk, speaker) pair
|
||||
* @return Return a matrix of shape (sample_indexes.size(), embedding_dim)
|
||||
* where ans.row[i] contains the embedding for the
|
||||
* i-th (chunk, speaker) pair
|
||||
*/
|
||||
Matrix2D ComputeEmbeddings(
|
||||
const float *audio, int32_t n,
|
||||
const std::vector<std::vector<Int32Pair>> &sample_indexes,
|
||||
OfflineSpeakerDiarizationProgressCallback callback,
|
||||
void *callback_arg) const {
|
||||
const auto &meta_data = segmentation_model_.GetModelMetaData();
|
||||
int32_t sample_rate = meta_data.sample_rate;
|
||||
Matrix2D ans(sample_indexes.size(), embedding_extractor_.Dim());
|
||||
|
||||
int32_t k = 0;
|
||||
for (const auto &v : sample_indexes) {
|
||||
auto stream = embedding_extractor_.CreateStream();
|
||||
for (const auto &p : v) {
|
||||
int32_t end = (p.second <= n) ? p.second : n;
|
||||
int32_t num_samples = end - p.first;
|
||||
|
||||
if (num_samples > 0) {
|
||||
stream->AcceptWaveform(sample_rate, audio + p.first, num_samples);
|
||||
}
|
||||
}
|
||||
|
||||
stream->InputFinished();
|
||||
if (!embedding_extractor_.IsReady(stream.get())) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"This segment is too short, which should not happen since we have "
|
||||
"already filtered short segments");
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
std::vector<float> embedding = embedding_extractor_.Compute(stream.get());
|
||||
|
||||
std::copy(embedding.begin(), embedding.end(), &ans(k, 0));
|
||||
|
||||
k += 1;
|
||||
|
||||
if (callback) {
|
||||
callback(k, ans.rows(), callback_arg);
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::unordered_map<Int32Pair, int32_t, PairHash> ConvertChunkSpeakerToCluster(
|
||||
const std::vector<Int32Pair> &chunk_speaker_pair,
|
||||
const std::vector<int32_t> &cluster_labels) const {
|
||||
std::unordered_map<Int32Pair, int32_t, PairHash> ans;
|
||||
|
||||
int32_t k = 0;
|
||||
for (const auto &p : chunk_speaker_pair) {
|
||||
ans[p] = cluster_labels[k];
|
||||
k += 1;
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::vector<Matrix2DInt32> ReLabel(
|
||||
const std::vector<Matrix2DInt32> &labels, int32_t max_cluster_index,
|
||||
std::unordered_map<Int32Pair, int32_t, PairHash> chunk_speaker_to_cluster)
|
||||
const {
|
||||
std::vector<Matrix2DInt32> new_labels;
|
||||
new_labels.reserve(labels.size());
|
||||
|
||||
int32_t chunk_index = 0;
|
||||
for (const auto &label : labels) {
|
||||
Matrix2DInt32 new_label(label.rows(), max_cluster_index + 1);
|
||||
new_label.setZero();
|
||||
|
||||
Matrix2DInt32 t = label.transpose();
|
||||
// t: (num_speakers, num_frames)
|
||||
|
||||
for (int32_t speaker_index = 0; speaker_index != t.rows();
|
||||
++speaker_index) {
|
||||
if (chunk_speaker_to_cluster.count({chunk_index, speaker_index}) == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int32_t new_speaker_index =
|
||||
chunk_speaker_to_cluster.at({chunk_index, speaker_index});
|
||||
|
||||
for (int32_t k = 0; k != t.cols(); ++k) {
|
||||
if (t(speaker_index, k) == 1) {
|
||||
new_label(k, new_speaker_index) = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
new_labels.push_back(std::move(new_label));
|
||||
|
||||
chunk_index += 1;
|
||||
}
|
||||
|
||||
return new_labels;
|
||||
}
|
||||
|
||||
Matrix2DInt32 ComputeSpeakerCount(const std::vector<Matrix2DInt32> &labels,
|
||||
int32_t num_samples) const {
|
||||
const auto &meta_data = segmentation_model_.GetModelMetaData();
|
||||
int32_t window_size = meta_data.window_size;
|
||||
int32_t window_shift = meta_data.window_shift;
|
||||
int32_t receptive_field_shift = meta_data.receptive_field_shift;
|
||||
|
||||
int32_t num_chunks = labels.size();
|
||||
|
||||
int32_t num_frames = (window_size + (num_chunks - 1) * window_shift) /
|
||||
receptive_field_shift +
|
||||
1;
|
||||
|
||||
Matrix2DInt32 count(num_frames, labels[0].cols());
|
||||
count.setZero();
|
||||
|
||||
for (int32_t i = 0; i != num_chunks; ++i) {
|
||||
int32_t start =
|
||||
static_cast<float>(i) * window_shift / receptive_field_shift + 0.5;
|
||||
|
||||
auto seq = Eigen::seqN(start, labels[i].rows());
|
||||
|
||||
count(seq, Eigen::all).array() += labels[i].array();
|
||||
}
|
||||
|
||||
bool has_last_chunk = (num_samples - window_size) % window_shift > 0;
|
||||
|
||||
if (has_last_chunk) {
|
||||
return count;
|
||||
}
|
||||
|
||||
int32_t last_frame = num_samples / receptive_field_shift;
|
||||
return count(Eigen::seq(0, last_frame), Eigen::all);
|
||||
}
|
||||
|
||||
Matrix2DInt32 FinalizeLabels(const Matrix2DInt32 &count,
|
||||
const Int32RowVector &speakers_per_frame) const {
|
||||
int32_t num_rows = count.rows();
|
||||
int32_t num_cols = count.cols();
|
||||
|
||||
Matrix2DInt32 ans(num_rows, num_cols);
|
||||
ans.setZero();
|
||||
|
||||
for (int32_t i = 0; i != num_rows; ++i) {
|
||||
int32_t k = speakers_per_frame[i];
|
||||
if (k == 0) {
|
||||
continue;
|
||||
}
|
||||
auto top_k = TopkIndex(&count(i, 0), num_cols, k);
|
||||
|
||||
for (int32_t m : top_k) {
|
||||
ans(i, m) = 1;
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
OfflineSpeakerDiarizationResult ComputeResult(
|
||||
const Matrix2DInt32 &final_labels) const {
|
||||
Matrix2DInt32 final_labels_t = final_labels.transpose();
|
||||
int32_t num_speakers = final_labels_t.rows();
|
||||
int32_t num_frames = final_labels_t.cols();
|
||||
|
||||
const auto &meta_data = segmentation_model_.GetModelMetaData();
|
||||
int32_t window_size = meta_data.window_size;
|
||||
int32_t window_shift = meta_data.window_shift;
|
||||
int32_t receptive_field_shift = meta_data.receptive_field_shift;
|
||||
int32_t receptive_field_size = meta_data.receptive_field_size;
|
||||
int32_t sample_rate = meta_data.sample_rate;
|
||||
|
||||
float scale = static_cast<float>(receptive_field_shift) / sample_rate;
|
||||
float scale_offset = 0.5 * receptive_field_size / sample_rate;
|
||||
|
||||
OfflineSpeakerDiarizationResult ans;
|
||||
|
||||
for (int32_t speaker_index = 0; speaker_index != num_speakers;
|
||||
++speaker_index) {
|
||||
std::vector<OfflineSpeakerDiarizationSegment> this_speaker;
|
||||
|
||||
bool is_active = final_labels_t(speaker_index, 0) > 0;
|
||||
int32_t start_index = is_active ? 0 : -1;
|
||||
|
||||
for (int32_t frame_index = 1; frame_index != num_frames; ++frame_index) {
|
||||
if (is_active) {
|
||||
if (final_labels_t(speaker_index, frame_index) == 0) {
|
||||
float start_time = start_index * scale + scale_offset;
|
||||
float end_time = frame_index * scale + scale_offset;
|
||||
|
||||
OfflineSpeakerDiarizationSegment segment(start_time, end_time,
|
||||
speaker_index);
|
||||
this_speaker.push_back(segment);
|
||||
|
||||
is_active = false;
|
||||
}
|
||||
} else if (final_labels_t(speaker_index, frame_index) == 1) {
|
||||
is_active = true;
|
||||
start_index = frame_index;
|
||||
}
|
||||
}
|
||||
|
||||
if (is_active) {
|
||||
float start_time = start_index * scale + scale_offset;
|
||||
float end_time = (num_frames - 1) * scale + scale_offset;
|
||||
|
||||
OfflineSpeakerDiarizationSegment segment(start_time, end_time,
|
||||
speaker_index);
|
||||
this_speaker.push_back(segment);
|
||||
}
|
||||
|
||||
// merge segments if the gap between them is less than min_duration_off
|
||||
MergeSegments(&this_speaker);
|
||||
|
||||
for (const auto &seg : this_speaker) {
|
||||
if (seg.Duration() > config_.min_duration_on) {
|
||||
ans.Add(seg);
|
||||
}
|
||||
}
|
||||
} // for (int32_t speaker_index = 0; speaker_index != num_speakers;
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
void MergeSegments(
|
||||
std::vector<OfflineSpeakerDiarizationSegment> *segments) const {
|
||||
float min_duration_off = config_.min_duration_off;
|
||||
bool changed = true;
|
||||
while (changed) {
|
||||
changed = false;
|
||||
for (int32_t i = 0; i < static_cast<int32_t>(segments->size()) - 1; ++i) {
|
||||
auto s = (*segments)[i].Merge((*segments)[i + 1], min_duration_off);
|
||||
if (s) {
|
||||
(*segments)[i] = s.value();
|
||||
segments->erase(segments->begin() + i + 1);
|
||||
|
||||
changed = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineSpeakerDiarizationConfig config_;
|
||||
OfflineSpeakerSegmentationPyannoteModel segmentation_model_;
|
||||
SpeakerEmbeddingExtractor embedding_extractor_;
|
||||
FastClustering clustering_;
|
||||
Matrix2DInt32 powerset_mapping_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_PYANNOTE_IMPL_H_
|
||||
110
sherpa-onnx/csrc/offline-speaker-diarization-result.cc
Normal file
110
sherpa-onnx/csrc/offline-speaker-diarization-result.cc
Normal file
@@ -0,0 +1,110 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-diarization-result.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
OfflineSpeakerDiarizationSegment::OfflineSpeakerDiarizationSegment(
|
||||
float start, float end, int32_t speaker, const std::string &text /*= {}*/) {
|
||||
if (start > end) {
|
||||
SHERPA_ONNX_LOGE("start %.3f should be less than end %.3f", start, end);
|
||||
SHERPA_ONNX_EXIT(-1);
|
||||
}
|
||||
|
||||
start_ = start;
|
||||
end_ = end;
|
||||
speaker_ = speaker;
|
||||
text_ = text;
|
||||
}
|
||||
|
||||
std::optional<OfflineSpeakerDiarizationSegment>
|
||||
OfflineSpeakerDiarizationSegment::Merge(
|
||||
const OfflineSpeakerDiarizationSegment &other, float gap) const {
|
||||
if (other.speaker_ != speaker_) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"The two segments should have the same speaker. this->speaker: %d, "
|
||||
"other.speaker: %d",
|
||||
speaker_, other.speaker_);
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
if (end_ < other.start_ && end_ + gap >= other.start_) {
|
||||
return OfflineSpeakerDiarizationSegment(start_, other.end_, speaker_);
|
||||
} else if (other.end_ < start_ && other.end_ + gap >= start_) {
|
||||
return OfflineSpeakerDiarizationSegment(other.start_, end_, speaker_);
|
||||
} else {
|
||||
return std::nullopt;
|
||||
}
|
||||
}
|
||||
|
||||
std::string OfflineSpeakerDiarizationSegment::ToString() const {
|
||||
char s[128];
|
||||
snprintf(s, sizeof(s), "%.3f -- %.3f speaker_%02d", start_, end_, speaker_);
|
||||
|
||||
std::ostringstream os;
|
||||
os << s;
|
||||
|
||||
if (!text_.empty()) {
|
||||
os << " " << text_;
|
||||
}
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
void OfflineSpeakerDiarizationResult::Add(
|
||||
const OfflineSpeakerDiarizationSegment &segment) {
|
||||
segments_.push_back(segment);
|
||||
}
|
||||
|
||||
int32_t OfflineSpeakerDiarizationResult::NumSpeakers() const {
|
||||
std::unordered_set<int32_t> count;
|
||||
for (const auto &s : segments_) {
|
||||
count.insert(s.Speaker());
|
||||
}
|
||||
|
||||
return count.size();
|
||||
}
|
||||
|
||||
int32_t OfflineSpeakerDiarizationResult::NumSegments() const {
|
||||
return segments_.size();
|
||||
}
|
||||
|
||||
// Return a list of segments sorted by segment.start time
|
||||
std::vector<OfflineSpeakerDiarizationSegment>
|
||||
OfflineSpeakerDiarizationResult::SortByStartTime() const {
|
||||
auto ans = segments_;
|
||||
std::sort(ans.begin(), ans.end(), [](const auto &a, const auto &b) {
|
||||
return (a.Start() < b.Start()) ||
|
||||
((a.Start() == b.Start()) && (a.Speaker() < b.Speaker()));
|
||||
});
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::vector<std::vector<OfflineSpeakerDiarizationSegment>>
|
||||
OfflineSpeakerDiarizationResult::SortBySpeaker() const {
|
||||
auto tmp = segments_;
|
||||
std::sort(tmp.begin(), tmp.end(), [](const auto &a, const auto &b) {
|
||||
return (a.Speaker() < b.Speaker()) ||
|
||||
((a.Speaker() == b.Speaker()) && (a.Start() < b.Start()));
|
||||
});
|
||||
|
||||
std::vector<std::vector<OfflineSpeakerDiarizationSegment>> ans(NumSpeakers());
|
||||
for (auto &s : tmp) {
|
||||
ans[s.Speaker()].push_back(std::move(s));
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
65
sherpa-onnx/csrc/offline-speaker-diarization-result.h
Normal file
65
sherpa-onnx/csrc/offline-speaker-diarization-result.h
Normal file
@@ -0,0 +1,65 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-diarization-result.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineSpeakerDiarizationSegment {
|
||||
public:
|
||||
OfflineSpeakerDiarizationSegment(float start, float end, int32_t speaker,
|
||||
const std::string &text = {});
|
||||
|
||||
// If the gap between the two segments is less than the given gap, then we
|
||||
// merge them and return a new segment. Otherwise, it returns null.
|
||||
std::optional<OfflineSpeakerDiarizationSegment> Merge(
|
||||
const OfflineSpeakerDiarizationSegment &other, float gap) const;
|
||||
|
||||
float Start() const { return start_; }
|
||||
float End() const { return end_; }
|
||||
int32_t Speaker() const { return speaker_; }
|
||||
const std::string &Text() const { return text_; }
|
||||
float Duration() const { return end_ - start_; }
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
private:
|
||||
float start_; // in seconds
|
||||
float end_; // in seconds
|
||||
int32_t speaker_; // ID of the speaker, starting from 0
|
||||
std::string text_; // If not empty, it contains the speech recognition result
|
||||
// of this segment
|
||||
};
|
||||
|
||||
class OfflineSpeakerDiarizationResult {
|
||||
public:
|
||||
// Add a new segment
|
||||
void Add(const OfflineSpeakerDiarizationSegment &segment);
|
||||
|
||||
// Number of distinct speakers contained in this object at this point
|
||||
int32_t NumSpeakers() const;
|
||||
|
||||
int32_t NumSegments() const;
|
||||
|
||||
// Return a list of segments sorted by segment.start time
|
||||
std::vector<OfflineSpeakerDiarizationSegment> SortByStartTime() const;
|
||||
|
||||
// ans.size() == NumSpeakers().
|
||||
// ans[i] is for speaker_i and is sorted by start time
|
||||
std::vector<std::vector<OfflineSpeakerDiarizationSegment>> SortBySpeaker()
|
||||
const;
|
||||
|
||||
public:
|
||||
std::vector<OfflineSpeakerDiarizationSegment> segments_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_RESULT_H_
|
||||
79
sherpa-onnx/csrc/offline-speaker-diarization.cc
Normal file
79
sherpa-onnx/csrc/offline-speaker-diarization.cc
Normal file
@@ -0,0 +1,79 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-diarization.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineSpeakerDiarizationConfig::Register(ParseOptions *po) {
|
||||
ParseOptions po_segmentation("segmentation", po);
|
||||
segmentation.Register(&po_segmentation);
|
||||
|
||||
ParseOptions po_embedding("embedding", po);
|
||||
embedding.Register(&po_embedding);
|
||||
|
||||
ParseOptions po_clustering("clustering", po);
|
||||
clustering.Register(&po_clustering);
|
||||
|
||||
po->Register("min-duration-on", &min_duration_on,
|
||||
"if a segment is less than this value, then it is discarded. "
|
||||
"Set it to 0 so that no segment is discarded");
|
||||
|
||||
po->Register("min-duration-off", &min_duration_off,
|
||||
"if the gap between to segments of the same speaker is less "
|
||||
"than this value, then these two segments are merged into a "
|
||||
"single segment. We do it recursively.");
|
||||
}
|
||||
|
||||
bool OfflineSpeakerDiarizationConfig::Validate() const {
|
||||
if (!segmentation.Validate()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!embedding.Validate()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!clustering.Validate()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineSpeakerDiarizationConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineSpeakerDiarizationConfig(";
|
||||
os << "segmentation=" << segmentation.ToString() << ", ";
|
||||
os << "embedding=" << embedding.ToString() << ", ";
|
||||
os << "clustering=" << clustering.ToString() << ", ";
|
||||
os << "min_duration_on=" << min_duration_on << ", ";
|
||||
os << "min_duration_off=" << min_duration_off << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
OfflineSpeakerDiarization::OfflineSpeakerDiarization(
|
||||
const OfflineSpeakerDiarizationConfig &config)
|
||||
: impl_(OfflineSpeakerDiarizationImpl::Create(config)) {}
|
||||
|
||||
OfflineSpeakerDiarization::~OfflineSpeakerDiarization() = default;
|
||||
|
||||
int32_t OfflineSpeakerDiarization::SampleRate() const {
|
||||
return impl_->SampleRate();
|
||||
}
|
||||
|
||||
OfflineSpeakerDiarizationResult OfflineSpeakerDiarization::Process(
|
||||
const float *audio, int32_t n,
|
||||
OfflineSpeakerDiarizationProgressCallback callback /*= nullptr*/,
|
||||
void *callback_arg /*= nullptr*/) const {
|
||||
return impl_->Process(audio, n, callback, callback_arg);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
73
sherpa-onnx/csrc/offline-speaker-diarization.h
Normal file
73
sherpa-onnx/csrc/offline-speaker-diarization.h
Normal file
@@ -0,0 +1,73 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-diarization.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/fast-clustering-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization-result.h"
|
||||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineSpeakerDiarizationConfig {
|
||||
OfflineSpeakerSegmentationModelConfig segmentation;
|
||||
SpeakerEmbeddingExtractorConfig embedding;
|
||||
FastClusteringConfig clustering;
|
||||
|
||||
// if a segment is less than this value, then it is discarded
|
||||
float min_duration_on = 0.3; // in seconds
|
||||
|
||||
// if the gap between to segments of the same speaker is less than this value,
|
||||
// then these two segments are merged into a single segment.
|
||||
// We do this recursively.
|
||||
float min_duration_off = 0.5; // in seconds
|
||||
|
||||
OfflineSpeakerDiarizationConfig() = default;
|
||||
|
||||
OfflineSpeakerDiarizationConfig(
|
||||
const OfflineSpeakerSegmentationModelConfig &segmentation,
|
||||
const SpeakerEmbeddingExtractorConfig &embedding,
|
||||
const FastClusteringConfig &clustering)
|
||||
: segmentation(segmentation),
|
||||
embedding(embedding),
|
||||
clustering(clustering) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class OfflineSpeakerDiarizationImpl;
|
||||
|
||||
using OfflineSpeakerDiarizationProgressCallback = std::function<int32_t(
|
||||
int32_t processed_chunks, int32_t num_chunks, void *arg)>;
|
||||
|
||||
class OfflineSpeakerDiarization {
|
||||
public:
|
||||
explicit OfflineSpeakerDiarization(
|
||||
const OfflineSpeakerDiarizationConfig &config);
|
||||
|
||||
~OfflineSpeakerDiarization();
|
||||
|
||||
// Expected sample rate of the input audio samples
|
||||
int32_t SampleRate() const;
|
||||
|
||||
OfflineSpeakerDiarizationResult Process(
|
||||
const float *audio, int32_t n,
|
||||
OfflineSpeakerDiarizationProgressCallback callback = nullptr,
|
||||
void *callback_arg = nullptr) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<OfflineSpeakerDiarizationImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_DIARIZATION_H_
|
||||
@@ -0,0 +1,57 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineSpeakerSegmentationModelConfig::Register(ParseOptions *po) {
|
||||
pyannote.Register(po);
|
||||
|
||||
po->Register("num-threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
po->Register("debug", &debug,
|
||||
"true to print model information while loading it.");
|
||||
|
||||
po->Register("provider", &provider,
|
||||
"Specify a provider to use: cpu, cuda, coreml");
|
||||
}
|
||||
|
||||
bool OfflineSpeakerSegmentationModelConfig::Validate() const {
|
||||
if (num_threads < 1) {
|
||||
SHERPA_ONNX_LOGE("num_threads should be > 0. Given %d", num_threads);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!pyannote.model.empty()) {
|
||||
return pyannote.Validate();
|
||||
}
|
||||
|
||||
if (pyannote.model.empty()) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"You have to provide at least one speaker segmentation model");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineSpeakerSegmentationModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineSpeakerSegmentationModelConfig(";
|
||||
os << "pyannote=" << pyannote.ToString() << ", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
40
sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h
Normal file
40
sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h
Normal file
@@ -0,0 +1,40 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineSpeakerSegmentationModelConfig {
|
||||
OfflineSpeakerSegmentationPyannoteModelConfig pyannote;
|
||||
|
||||
int32_t num_threads = 1;
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
OfflineSpeakerSegmentationModelConfig() = default;
|
||||
|
||||
explicit OfflineSpeakerSegmentationModelConfig(
|
||||
const OfflineSpeakerSegmentationPyannoteModelConfig &pyannote,
|
||||
int32_t num_threads, bool debug, const std::string &provider)
|
||||
: pyannote(pyannote),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_MODEL_CONFIG_H_
|
||||
@@ -0,0 +1,38 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h"
|
||||
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OfflineSpeakerSegmentationPyannoteModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("pyannote-model", &model,
|
||||
"Path to model.onnx of the Pyannote segmentation model.");
|
||||
}
|
||||
|
||||
bool OfflineSpeakerSegmentationPyannoteModelConfig::Validate() const {
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("Pyannote segmentation model: '%s' does not exist",
|
||||
model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OfflineSpeakerSegmentationPyannoteModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OfflineSpeakerSegmentationPyannoteModelConfig(";
|
||||
os << "model=\"" << model << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,30 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OfflineSpeakerSegmentationPyannoteModelConfig {
|
||||
std::string model;
|
||||
|
||||
OfflineSpeakerSegmentationPyannoteModelConfig() = default;
|
||||
|
||||
explicit OfflineSpeakerSegmentationPyannoteModelConfig(
|
||||
const std::string &model)
|
||||
: model(model) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_CONFIG_H_
|
||||
@@ -0,0 +1,29 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// If you are not sure what each field means, please
|
||||
// have a look of the Python file in the model directory that
|
||||
// you have downloaded.
|
||||
struct OfflineSpeakerSegmentationPyannoteModelMetaData {
|
||||
int32_t sample_rate = 0;
|
||||
int32_t window_size = 0; // in samples
|
||||
int32_t window_shift = 0; // in samples
|
||||
int32_t receptive_field_size = 0; // in samples
|
||||
int32_t receptive_field_shift = 0; // in samples
|
||||
int32_t num_speakers = 0;
|
||||
int32_t powerset_max_classes = 0;
|
||||
int32_t num_classes = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_META_DATA_H_
|
||||
108
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
Normal file
108
sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
Normal file
@@ -0,0 +1,108 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h"
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineSpeakerSegmentationPyannoteModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OfflineSpeakerSegmentationModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(config_.pyannote.model);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
|
||||
const {
|
||||
return meta_data_;
|
||||
}
|
||||
|
||||
Ort::Value Forward(Ort::Value x) {
|
||||
auto out = sess_->Run({}, input_names_ptr_.data(), &x, 1,
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
|
||||
return std::move(out[0]);
|
||||
}
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||
if (config_.debug) {
|
||||
std::ostringstream os;
|
||||
PrintModelMetadata(os, meta_data);
|
||||
SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
|
||||
}
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.sample_rate, "sample_rate");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.window_size, "window_size");
|
||||
|
||||
meta_data_.window_shift =
|
||||
static_cast<int32_t>(0.1 * meta_data_.window_size);
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_size,
|
||||
"receptive_field_size");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.receptive_field_shift,
|
||||
"receptive_field_shift");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.num_speakers, "num_speakers");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.powerset_max_classes,
|
||||
"powerset_max_classes");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.num_classes, "num_classes");
|
||||
}
|
||||
|
||||
private:
|
||||
OfflineSpeakerSegmentationModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
|
||||
OfflineSpeakerSegmentationPyannoteModelMetaData meta_data_;
|
||||
};
|
||||
|
||||
OfflineSpeakerSegmentationPyannoteModel::
|
||||
OfflineSpeakerSegmentationPyannoteModel(
|
||||
const OfflineSpeakerSegmentationModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
OfflineSpeakerSegmentationPyannoteModel::
|
||||
~OfflineSpeakerSegmentationPyannoteModel() = default;
|
||||
|
||||
const OfflineSpeakerSegmentationPyannoteModelMetaData &
|
||||
OfflineSpeakerSegmentationPyannoteModel::GetModelMetaData() const {
|
||||
return impl_->GetModelMetaData();
|
||||
}
|
||||
|
||||
Ort::Value OfflineSpeakerSegmentationPyannoteModel::Forward(
|
||||
Ort::Value x) const {
|
||||
return impl_->Forward(std::move(x));
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
@@ -0,0 +1,40 @@
|
||||
// sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model.h
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
#ifndef SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-speaker-segmentation-pyannote-model-meta-data.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OfflineSpeakerSegmentationPyannoteModel {
|
||||
public:
|
||||
explicit OfflineSpeakerSegmentationPyannoteModel(
|
||||
const OfflineSpeakerSegmentationModelConfig &config);
|
||||
|
||||
~OfflineSpeakerSegmentationPyannoteModel();
|
||||
|
||||
const OfflineSpeakerSegmentationPyannoteModelMetaData &GetModelMetaData()
|
||||
const;
|
||||
|
||||
/**
|
||||
* @param x A 3-D float tensor of shape (batch_size, 1, num_samples)
|
||||
* @return Return a float tensor of
|
||||
* shape (batch_size, num_frames, num_speakers). Note that
|
||||
* num_speakers here uses powerset encoding.
|
||||
*/
|
||||
Ort::Value Forward(Ort::Value x) const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_OFFLINE_SPEAKER_SEGMENTATION_PYANNOTE_MODEL_H_
|
||||
@@ -61,8 +61,10 @@ void TensorrtConfig::Register(ParseOptions *po) {
|
||||
|
||||
bool TensorrtConfig::Validate() const {
|
||||
if (trt_max_workspace_size < 0) {
|
||||
SHERPA_ONNX_LOGE("trt_max_workspace_size: %ld is not valid.",
|
||||
trt_max_workspace_size);
|
||||
std::ostringstream os;
|
||||
os << "trt_max_workspace_size: " << trt_max_workspace_size
|
||||
<< " is not valid.";
|
||||
SHERPA_ONNX_LOGE("%s", os.str().c_str());
|
||||
return false;
|
||||
}
|
||||
if (trt_max_partition_iterations < 0) {
|
||||
|
||||
@@ -35,9 +35,9 @@ static void OrtStatusFailure(OrtStatus *status, const char *s) {
|
||||
api.ReleaseStatus(status);
|
||||
}
|
||||
|
||||
static Ort::SessionOptions GetSessionOptionsImpl(
|
||||
Ort::SessionOptions GetSessionOptionsImpl(
|
||||
int32_t num_threads, const std::string &provider_str,
|
||||
const ProviderConfig *provider_config = nullptr) {
|
||||
const ProviderConfig *provider_config /*= nullptr*/) {
|
||||
Provider p = StringToProvider(provider_str);
|
||||
|
||||
Ort::SessionOptions sess_opts;
|
||||
@@ -259,10 +259,6 @@ Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
|
||||
&config.provider_config);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config) {
|
||||
return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
|
||||
}
|
||||
@@ -271,38 +267,4 @@ Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config) {
|
||||
return GetSessionOptionsImpl(config.lm_num_threads, config.lm_provider);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_TTS
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
#endif
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpeakerEmbeddingExtractorConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpokenLanguageIdentificationConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const OfflinePunctuationModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const OnlinePunctuationModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -8,53 +8,28 @@
|
||||
#include <string>
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/audio-tagging-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-punctuation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
#include "sherpa-onnx/csrc/spoken-language-identification.h"
|
||||
#include "sherpa-onnx/csrc/vad-model-config.h"
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_TTS
|
||||
#include "sherpa-onnx/csrc/offline-tts-model-config.h"
|
||||
#endif
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
Ort::SessionOptions GetSessionOptionsImpl(
|
||||
int32_t num_threads, const std::string &provider_str,
|
||||
const ProviderConfig *provider_config = nullptr);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config);
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineModelConfig &config,
|
||||
const std::string &model_type);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineLMConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const OnlineLMConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const VadModelConfig &config);
|
||||
|
||||
#if SHERPA_ONNX_ENABLE_TTS
|
||||
Ort::SessionOptions GetSessionOptions(const OfflineTtsModelConfig &config);
|
||||
#endif
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpeakerEmbeddingExtractorConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const SpokenLanguageIdentificationConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const OfflinePunctuationModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const OnlinePunctuationModelConfig &config);
|
||||
template <typename T>
|
||||
Ort::SessionOptions GetSessionOptions(const T &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
|
||||
133
sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
Normal file
133
sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
Normal file
@@ -0,0 +1,133 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-offline-speaker-diarization.cc
|
||||
//
|
||||
// Copyright (c) 2024 Xiaomi Corporation
|
||||
|
||||
#include "sherpa-onnx/csrc/offline-speaker-diarization.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
#include "sherpa-onnx/csrc/wave-reader.h"
|
||||
|
||||
static int32_t ProgressCallback(int32_t processed_chunks, int32_t num_chunks,
|
||||
void *arg) {
|
||||
float progress = 100.0 * processed_chunks / num_chunks;
|
||||
fprintf(stderr, "progress %.2f%%\n", progress);
|
||||
|
||||
// the return value is currently ignored
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Offline/Non-streaming speaker diarization with sherpa-onnx
|
||||
Usage example:
|
||||
|
||||
Step 1: Download a speaker segmentation model
|
||||
|
||||
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
|
||||
for a list of available models. The following is an example
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
tar xvf sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
rm sherpa-onnx-pyannote-segmentation-3-0.tar.bz2
|
||||
|
||||
Step 2: Download a speaker embedding extractor model
|
||||
|
||||
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-recongition-models
|
||||
for a list of available models. The following is an example
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx
|
||||
|
||||
Step 3. Download test wave files
|
||||
|
||||
Please visit https://github.com/k2-fsa/sherpa-onnx/releases/tag/speaker-segmentation-models
|
||||
for a list of available test wave files. The following is an example
|
||||
|
||||
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-segmentation-models/0-four-speakers-zh.wav
|
||||
|
||||
Step 4. Build sherpa-onnx
|
||||
|
||||
Step 5. Run it
|
||||
|
||||
./bin/sherpa-onnx-offline-speaker-diarization \
|
||||
--clustering.num-clusters=4 \
|
||||
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
|
||||
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
|
||||
./0-four-speakers-zh.wav
|
||||
|
||||
Since we know that there are four speakers in the test wave file, we use
|
||||
--clustering.num-clusters=4 in the above example.
|
||||
|
||||
If we don't know number of speakers in the given wave file, we can use
|
||||
the argument --clustering.cluster-threshold. The following is an example:
|
||||
|
||||
./bin/sherpa-onnx-offline-speaker-diarization \
|
||||
--clustering.cluster-threshold=0.90 \
|
||||
--segmentation.pyannote-model=./sherpa-onnx-pyannote-segmentation-3-0/model.onnx \
|
||||
--embedding.model=./3dspeaker_speech_eres2net_base_sv_zh-cn_3dspeaker_16k.onnx \
|
||||
./0-four-speakers-zh.wav
|
||||
|
||||
A larger threshold leads to few clusters, i.e., few speakers;
|
||||
a smaller threshold leads to more clusters, i.e., more speakers
|
||||
)usage";
|
||||
sherpa_onnx::OfflineSpeakerDiarizationConfig config;
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
config.Register(&po);
|
||||
po.Read(argc, argv);
|
||||
|
||||
std::cout << config.ToString() << "\n";
|
||||
|
||||
if (!config.Validate()) {
|
||||
po.PrintUsage();
|
||||
std::cerr << "Errors in config!\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (po.NumArgs() != 1) {
|
||||
std::cerr << "Error: Please provide exactly 1 wave file.\n\n";
|
||||
po.PrintUsage();
|
||||
return -1;
|
||||
}
|
||||
|
||||
sherpa_onnx::OfflineSpeakerDiarization sd(config);
|
||||
|
||||
std::cout << "Started\n";
|
||||
const auto begin = std::chrono::steady_clock::now();
|
||||
const std::string wav_filename = po.GetArg(1);
|
||||
int32_t sample_rate = -1;
|
||||
bool is_ok = false;
|
||||
const std::vector<float> samples =
|
||||
sherpa_onnx::ReadWave(wav_filename, &sample_rate, &is_ok);
|
||||
if (!is_ok) {
|
||||
std::cerr << "Failed to read " << wav_filename.c_str() << "\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (sample_rate != sd.SampleRate()) {
|
||||
std::cerr << "Expect sample rate " << sd.SampleRate()
|
||||
<< ". Given: " << sample_rate << "\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
float duration = samples.size() / static_cast<float>(sample_rate);
|
||||
|
||||
auto result =
|
||||
sd.Process(samples.data(), samples.size(), ProgressCallback, nullptr)
|
||||
.SortByStartTime();
|
||||
|
||||
for (const auto &r : result) {
|
||||
std::cout << r.ToString() << "\n";
|
||||
}
|
||||
|
||||
const auto end = std::chrono::steady_clock::now();
|
||||
float elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||
.count() /
|
||||
1000.;
|
||||
|
||||
fprintf(stderr, "Duration : %.3f s\n", duration);
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
float rtf = elapsed_seconds / duration;
|
||||
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
|
||||
elapsed_seconds, duration, rtf);
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -9,14 +9,15 @@
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
#include "sherpa-onnx/csrc/wave-writer.h"
|
||||
|
||||
int32_t audioCallback(const float * /*samples*/, int32_t n, float progress) {
|
||||
static int32_t AudioCallback(const float * /*samples*/, int32_t n,
|
||||
float progress) {
|
||||
printf("sample=%d, progress=%f\n", n, progress);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Offline text-to-speech with sherpa-onnx
|
||||
Offline/Non-streaming text-to-speech with sherpa-onnx
|
||||
|
||||
Usage example:
|
||||
|
||||
@@ -79,7 +80,7 @@ or details.
|
||||
sherpa_onnx::OfflineTts tts(config);
|
||||
|
||||
const auto begin = std::chrono::steady_clock::now();
|
||||
auto audio = tts.Generate(po.GetArg(1), sid, 1.0, audioCallback);
|
||||
auto audio = tts.Generate(po.GetArg(1), sid, 1.0, AudioCallback);
|
||||
const auto end = std::chrono::steady_clock::now();
|
||||
|
||||
if (audio.samples.empty()) {
|
||||
|
||||
@@ -19,7 +19,7 @@ The input text can contain English words.
|
||||
Usage:
|
||||
|
||||
Please download the model from:
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12.tar.bz2
|
||||
https://github.com/k2-fsa/sherpa-onnx/releases/download/punctuation-models/sherpa-onnx-online-punct-en-2024-08-06.tar.bz2
|
||||
|
||||
./bin/Release/sherpa-onnx-online-punctuation \
|
||||
--cnn-bilstm=/path/to/model.onnx \
|
||||
|
||||
@@ -26,12 +26,12 @@ void SpeakerEmbeddingExtractorConfig::Register(ParseOptions *po) {
|
||||
|
||||
bool SpeakerEmbeddingExtractorConfig::Validate() const {
|
||||
if (model.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --model");
|
||||
SHERPA_ONNX_LOGE("Please provide a speaker embedding extractor model");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(model)) {
|
||||
SHERPA_ONNX_LOGE("--speaker-embedding-model: '%s' does not exist",
|
||||
SHERPA_ONNX_LOGE("speaker embedding extractor model: '%s' does not exist",
|
||||
model.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user