Add C++ runtime for models from 3d-speaker (#523)
This commit is contained in:
101
sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
Normal file
101
sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
Normal file
@@ -0,0 +1,101 @@
|
||||
// sherpa-onnx/csrc/speaker-embedding-extractor-general-impl.h
|
||||
//
|
||||
// Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "Eigen/Dense"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-impl.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class SpeakerEmbeddingExtractorGeneralImpl
|
||||
: public SpeakerEmbeddingExtractorImpl {
|
||||
public:
|
||||
explicit SpeakerEmbeddingExtractorGeneralImpl(
|
||||
const SpeakerEmbeddingExtractorConfig &config)
|
||||
: model_(config) {}
|
||||
|
||||
int32_t Dim() const override { return model_.GetMetaData().output_dim; }
|
||||
|
||||
std::unique_ptr<OnlineStream> CreateStream() const override {
|
||||
FeatureExtractorConfig feat_config;
|
||||
const auto &meta_data = model_.GetMetaData();
|
||||
feat_config.sampling_rate = meta_data.sample_rate;
|
||||
feat_config.normalize_samples = meta_data.normalize_samples;
|
||||
|
||||
return std::make_unique<OnlineStream>(feat_config);
|
||||
}
|
||||
|
||||
bool IsReady(OnlineStream *s) const override {
|
||||
return s->GetNumProcessedFrames() < s->NumFramesReady();
|
||||
}
|
||||
|
||||
std::vector<float> Compute(OnlineStream *s) const override {
|
||||
int32_t num_frames = s->NumFramesReady() - s->GetNumProcessedFrames();
|
||||
if (num_frames <= 0) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Please make sure IsReady(s) returns true. num_frames: %d",
|
||||
num_frames);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<float> features =
|
||||
s->GetFrames(s->GetNumProcessedFrames(), num_frames);
|
||||
|
||||
s->GetNumProcessedFrames() += num_frames;
|
||||
|
||||
int32_t feat_dim = features.size() / num_frames;
|
||||
|
||||
const auto &meta_data = model_.GetMetaData();
|
||||
if (!meta_data.feature_normalize_type.empty()) {
|
||||
if (meta_data.feature_normalize_type == "global-mean") {
|
||||
SubtractGlobalMean(features.data(), num_frames, feat_dim);
|
||||
} else {
|
||||
SHERPA_ONNX_LOGE("Unsupported feature_normalize_type: %s",
|
||||
meta_data.feature_normalize_type.c_str());
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
std::array<int64_t, 3> x_shape{1, num_frames, feat_dim};
|
||||
Ort::Value x =
|
||||
Ort::Value::CreateTensor(memory_info, features.data(), features.size(),
|
||||
x_shape.data(), x_shape.size());
|
||||
Ort::Value embedding = model_.Compute(std::move(x));
|
||||
std::vector<int64_t> embedding_shape =
|
||||
embedding.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
std::vector<float> ans(embedding_shape[1]);
|
||||
std::copy(embedding.GetTensorData<float>(),
|
||||
embedding.GetTensorData<float>() + ans.size(), ans.begin());
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
void SubtractGlobalMean(float *p, int32_t num_frames,
|
||||
int32_t feat_dim) const {
|
||||
auto m = Eigen::Map<
|
||||
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
|
||||
p, num_frames, feat_dim);
|
||||
|
||||
m = m.rowwise() - m.colwise().mean();
|
||||
}
|
||||
|
||||
private:
|
||||
SpeakerEmbeddingExtractorModel model_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SPEAKER_EMBEDDING_EXTRACTOR_GENERAL_IMPL_H_
|
||||
Reference in New Issue
Block a user