Add C++ runtime for vocos (#2014)
This commit is contained in:
@@ -13,7 +13,6 @@
|
||||
#include "fst/extensions/far/far.h"
|
||||
#include "kaldifst/csrc/kaldi-fst-io.h"
|
||||
#include "kaldifst/csrc/text-normalizer.h"
|
||||
#include "sherpa-onnx/csrc/hifigan-vocoder.h"
|
||||
#include "sherpa-onnx/csrc/jieba-lexicon.h"
|
||||
#include "sherpa-onnx/csrc/lexicon.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -25,6 +24,7 @@
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/piper-phonemize-lexicon.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
#include "sherpa-onnx/csrc/vocoder.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
@@ -33,9 +33,7 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
|
||||
explicit OfflineTtsMatchaImpl(const OfflineTtsConfig &config)
|
||||
: config_(config),
|
||||
model_(std::make_unique<OfflineTtsMatchaModel>(config.model)),
|
||||
vocoder_(std::make_unique<HifiganVocoder>(
|
||||
config.model.num_threads, config.model.provider,
|
||||
config.model.matcha.vocoder)) {
|
||||
vocoder_(Vocoder::Create(config.model)) {
|
||||
InitFrontend();
|
||||
|
||||
if (!config.rule_fsts.empty()) {
|
||||
@@ -92,9 +90,7 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
|
||||
OfflineTtsMatchaImpl(Manager *mgr, const OfflineTtsConfig &config)
|
||||
: config_(config),
|
||||
model_(std::make_unique<OfflineTtsMatchaModel>(mgr, config.model)),
|
||||
vocoder_(std::make_unique<HifiganVocoder>(
|
||||
mgr, config.model.num_threads, config.model.provider,
|
||||
config.model.matcha.vocoder)) {
|
||||
vocoder_(Vocoder::Create(mgr, config.model)) {
|
||||
InitFrontend(mgr);
|
||||
|
||||
if (!config.rule_fsts.empty()) {
|
||||
@@ -382,22 +378,11 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
|
||||
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
|
||||
|
||||
Ort::Value mel = model_->Run(std::move(x_tensor), sid, speed);
|
||||
Ort::Value audio = vocoder_->Run(std::move(mel));
|
||||
|
||||
std::vector<int64_t> audio_shape =
|
||||
audio.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
int64_t total = 1;
|
||||
// The output shape may be (1, 1, total) or (1, total) or (total,)
|
||||
for (auto i : audio_shape) {
|
||||
total *= i;
|
||||
}
|
||||
|
||||
const float *p = audio.GetTensorData<float>();
|
||||
|
||||
GeneratedAudio ans;
|
||||
|
||||
ans.samples = vocoder_->Run(std::move(mel));
|
||||
ans.sample_rate = model_->GetMetaData().sample_rate;
|
||||
ans.samples = std::vector<float>(p, p + total);
|
||||
|
||||
float silence_scale = config_.silence_scale;
|
||||
if (silence_scale != 1) {
|
||||
@@ -410,7 +395,7 @@ class OfflineTtsMatchaImpl : public OfflineTtsImpl {
|
||||
private:
|
||||
OfflineTtsConfig config_;
|
||||
std::unique_ptr<OfflineTtsMatchaModel> model_;
|
||||
std::unique_ptr<HifiganVocoder> vocoder_;
|
||||
std::unique_ptr<Vocoder> vocoder_;
|
||||
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> tn_list_;
|
||||
std::unique_ptr<OfflineTtsFrontend> frontend_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user