Add C++ runtime for vocos (#2014)

This commit is contained in:
Fangjun Kuang
2025-03-17 17:05:15 +08:00
committed by GitHub
parent 623cdc9eec
commit 0aacf02dd8
62 changed files with 558 additions and 162 deletions

View File

@@ -13,7 +13,6 @@
#include "fst/extensions/far/far.h"
#include "kaldifst/csrc/kaldi-fst-io.h"
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/hifigan-vocoder.h"
#include "sherpa-onnx/csrc/jieba-lexicon.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
@@ -25,6 +24,7 @@
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/piper-phonemize-lexicon.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/vocoder.h"
namespace sherpa_onnx {
@@ -33,9 +33,7 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
explicit OfflineTtsMatchaImpl(const OfflineTtsConfig &config)
: config_(config),
model_(std::make_unique<OfflineTtsMatchaModel>(config.model)),
vocoder_(std::make_unique<HifiganVocoder>(
config.model.num_threads, config.model.provider,
config.model.matcha.vocoder)) {
vocoder_(Vocoder::Create(config.model)) {
InitFrontend();
if (!config.rule_fsts.empty()) {
@@ -92,9 +90,7 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
OfflineTtsMatchaImpl(Manager *mgr, const OfflineTtsConfig &config)
: config_(config),
model_(std::make_unique<OfflineTtsMatchaModel>(mgr, config.model)),
vocoder_(std::make_unique<HifiganVocoder>(
mgr, config.model.num_threads, config.model.provider,
config.model.matcha.vocoder)) {
vocoder_(Vocoder::Create(mgr, config.model)) {
InitFrontend(mgr);
if (!config.rule_fsts.empty()) {
@@ -382,22 +378,11 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
Ort::Value mel = model_->Run(std::move(x_tensor), sid, speed);
Ort::Value audio = vocoder_->Run(std::move(mel));
std::vector<int64_t> audio_shape =
audio.GetTensorTypeAndShapeInfo().GetShape();
int64_t total = 1;
// The output shape may be (1, 1, total) or (1, total) or (total,)
for (auto i : audio_shape) {
total *= i;
}
const float *p = audio.GetTensorData<float>();
GeneratedAudio ans;
ans.samples = vocoder_->Run(std::move(mel));
ans.sample_rate = model_->GetMetaData().sample_rate;
ans.samples = std::vector<float>(p, p + total);
float silence_scale = config_.silence_scale;
if (silence_scale != 1) {
@@ -410,7 +395,7 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
private:
OfflineTtsConfig config_;
std::unique_ptr<OfflineTtsMatchaModel> model_;
std::unique_ptr<HifiganVocoder> vocoder_;
std::unique_ptr<Vocoder> vocoder_;
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> tn_list_;
std::unique_ptr<OfflineTtsFrontend> frontend_;
};