Add C++ runtime for MeloTTS (#1138)

This commit is contained in:
Fangjun Kuang
2024-07-16 15:55:02 +08:00
committed by GitHub
parent 95485411fa
commit 960eb7529e
51 changed files with 693 additions and 156 deletions

View File

@@ -22,6 +22,7 @@
#include "sherpa-onnx/csrc/jieba-lexicon.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
#include "sherpa-onnx/csrc/offline-tts-character-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
#include "sherpa-onnx/csrc/offline-tts-impl.h"
@@ -174,26 +175,47 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}
}
std::vector<std::vector<int64_t>> x =
std::vector<TokenIDs> token_ids =
frontend_->ConvertTextToTokenIds(text, meta_data.voice);
if (x.empty() || (x.size() == 1 && x[0].empty())) {
if (token_ids.empty() ||
(token_ids.size() == 1 && token_ids[0].tokens.empty())) {
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
return {};
}
std::vector<std::vector<int64_t>> x;
std::vector<std::vector<int64_t>> tones;
x.reserve(token_ids.size());
for (auto &i : token_ids) {
x.push_back(std::move(i.tokens));
}
if (!token_ids[0].tones.empty()) {
tones.reserve(token_ids.size());
for (auto &i : token_ids) {
tones.push_back(std::move(i.tones));
}
}
// TODO(fangjun): add blank inside the frontend, not here
if (meta_data.add_blank && config_.model.vits.data_dir.empty() &&
meta_data.frontend != "characters") {
for (auto &k : x) {
k = AddBlank(k);
}
for (auto &k : tones) {
k = AddBlank(k);
}
}
int32_t x_size = static_cast<int32_t>(x.size());
if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
auto ans = Process(x, sid, speed);
auto ans = Process(x, tones, sid, speed);
if (callback) {
callback(ans.samples.data(), ans.samples.size(), 1.0);
}
@@ -202,9 +224,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
// the input text is too long, we process sentences within it in batches
// to avoid OOM. Batch size is config_.max_num_sentences
std::vector<std::vector<int64_t>> batch;
std::vector<std::vector<int64_t>> batch_x;
std::vector<std::vector<int64_t>> batch_tones;
int32_t batch_size = config_.max_num_sentences;
batch.reserve(config_.max_num_sentences);
batch_x.reserve(config_.max_num_sentences);
batch_tones.reserve(config_.max_num_sentences);
int32_t num_batches = x_size / batch_size;
if (config_.model.debug) {
@@ -221,12 +246,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
int32_t k = 0;
for (int32_t b = 0; b != num_batches && should_continue; ++b) {
batch.clear();
batch_x.clear();
batch_tones.clear();
for (int32_t i = 0; i != batch_size; ++i, ++k) {
batch.push_back(std::move(x[k]));
batch_x.push_back(std::move(x[k]));
if (!tones.empty()) {
batch_tones.push_back(std::move(tones[k]));
}
}
auto audio = Process(batch, sid, speed);
auto audio = Process(batch_x, batch_tones, sid, speed);
ans.sample_rate = audio.sample_rate;
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
@@ -239,14 +269,19 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}
}
batch.clear();
batch_x.clear();
batch_tones.clear();
while (k < static_cast<int32_t>(x.size()) && should_continue) {
batch.push_back(std::move(x[k]));
batch_x.push_back(std::move(x[k]));
if (!tones.empty()) {
batch_tones.push_back(std::move(tones[k]));
}
++k;
}
if (!batch.empty()) {
auto audio = Process(batch, sid, speed);
if (!batch_x.empty()) {
auto audio = Process(batch_x, batch_tones, sid, speed);
ans.sample_rate = audio.sample_rate;
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
audio.samples.end());
@@ -308,6 +343,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
if (meta_data.frontend == "characters") {
frontend_ = std::make_unique<OfflineTtsCharacterFrontend>(
config_.model.vits.tokens, meta_data);
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty() &&
meta_data.is_melo_tts) {
frontend_ = std::make_unique<MeloTtsLexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
config_.model.vits.dict_dir, model_->GetMetaData(),
config_.model.debug);
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) {
frontend_ = std::make_unique<JiebaLexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens,
@@ -344,6 +385,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
}
GeneratedAudio Process(const std::vector<std::vector<int64_t>> &tokens,
const std::vector<std::vector<int64_t>> &tones,
int32_t sid, float speed) const {
int32_t num_tokens = 0;
for (const auto &k : tokens) {
@@ -356,6 +398,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
x.insert(x.end(), k.begin(), k.end());
}
std::vector<int64_t> tone_list;
if (!tones.empty()) {
tone_list.reserve(num_tokens);
for (const auto &k : tones) {
tone_list.insert(tone_list.end(), k.begin(), k.end());
}
}
auto memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
@@ -363,7 +413,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
Ort::Value x_tensor = Ort::Value::CreateTensor(
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
Ort::Value audio = model_->Run(std::move(x_tensor), sid, speed);
Ort::Value tones_tensor{nullptr};
if (!tones.empty()) {
tones_tensor = Ort::Value::CreateTensor(memory_info, tone_list.data(),
tone_list.size(), x_shape.data(),
x_shape.size());
}
Ort::Value audio{nullptr};
if (tones.empty()) {
audio = model_->Run(std::move(x_tensor), sid, speed);
} else {
audio =
model_->Run(std::move(x_tensor), std::move(tones_tensor), sid, speed);
}
std::vector<int64_t> audio_shape =
audio.GetTensorTypeAndShapeInfo().GetShape();