Add C++ runtime for MeloTTS (#1138)
This commit is contained in:
@@ -22,6 +22,7 @@
|
||||
#include "sherpa-onnx/csrc/jieba-lexicon.h"
|
||||
#include "sherpa-onnx/csrc/lexicon.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/melo-tts-lexicon.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-character-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-impl.h"
|
||||
@@ -174,26 +175,47 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> x =
|
||||
std::vector<TokenIDs> token_ids =
|
||||
frontend_->ConvertTextToTokenIds(text, meta_data.voice);
|
||||
|
||||
if (x.empty() || (x.size() == 1 && x[0].empty())) {
|
||||
if (token_ids.empty() ||
|
||||
(token_ids.size() == 1 && token_ids[0].tokens.empty())) {
|
||||
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> x;
|
||||
std::vector<std::vector<int64_t>> tones;
|
||||
|
||||
x.reserve(token_ids.size());
|
||||
|
||||
for (auto &i : token_ids) {
|
||||
x.push_back(std::move(i.tokens));
|
||||
}
|
||||
|
||||
if (!token_ids[0].tones.empty()) {
|
||||
tones.reserve(token_ids.size());
|
||||
for (auto &i : token_ids) {
|
||||
tones.push_back(std::move(i.tones));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(fangjun): add blank inside the frontend, not here
|
||||
if (meta_data.add_blank && config_.model.vits.data_dir.empty() &&
|
||||
meta_data.frontend != "characters") {
|
||||
for (auto &k : x) {
|
||||
k = AddBlank(k);
|
||||
}
|
||||
|
||||
for (auto &k : tones) {
|
||||
k = AddBlank(k);
|
||||
}
|
||||
}
|
||||
|
||||
int32_t x_size = static_cast<int32_t>(x.size());
|
||||
|
||||
if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
|
||||
auto ans = Process(x, sid, speed);
|
||||
auto ans = Process(x, tones, sid, speed);
|
||||
if (callback) {
|
||||
callback(ans.samples.data(), ans.samples.size(), 1.0);
|
||||
}
|
||||
@@ -202,9 +224,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
|
||||
// the input text is too long, we process sentences within it in batches
|
||||
// to avoid OOM. Batch size is config_.max_num_sentences
|
||||
std::vector<std::vector<int64_t>> batch;
|
||||
std::vector<std::vector<int64_t>> batch_x;
|
||||
std::vector<std::vector<int64_t>> batch_tones;
|
||||
|
||||
int32_t batch_size = config_.max_num_sentences;
|
||||
batch.reserve(config_.max_num_sentences);
|
||||
batch_x.reserve(config_.max_num_sentences);
|
||||
batch_tones.reserve(config_.max_num_sentences);
|
||||
int32_t num_batches = x_size / batch_size;
|
||||
|
||||
if (config_.model.debug) {
|
||||
@@ -221,12 +246,17 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
int32_t k = 0;
|
||||
|
||||
for (int32_t b = 0; b != num_batches && should_continue; ++b) {
|
||||
batch.clear();
|
||||
batch_x.clear();
|
||||
batch_tones.clear();
|
||||
for (int32_t i = 0; i != batch_size; ++i, ++k) {
|
||||
batch.push_back(std::move(x[k]));
|
||||
batch_x.push_back(std::move(x[k]));
|
||||
|
||||
if (!tones.empty()) {
|
||||
batch_tones.push_back(std::move(tones[k]));
|
||||
}
|
||||
}
|
||||
|
||||
auto audio = Process(batch, sid, speed);
|
||||
auto audio = Process(batch_x, batch_tones, sid, speed);
|
||||
ans.sample_rate = audio.sample_rate;
|
||||
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
||||
audio.samples.end());
|
||||
@@ -239,14 +269,19 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
}
|
||||
}
|
||||
|
||||
batch.clear();
|
||||
batch_x.clear();
|
||||
batch_tones.clear();
|
||||
while (k < static_cast<int32_t>(x.size()) && should_continue) {
|
||||
batch.push_back(std::move(x[k]));
|
||||
batch_x.push_back(std::move(x[k]));
|
||||
if (!tones.empty()) {
|
||||
batch_tones.push_back(std::move(tones[k]));
|
||||
}
|
||||
|
||||
++k;
|
||||
}
|
||||
|
||||
if (!batch.empty()) {
|
||||
auto audio = Process(batch, sid, speed);
|
||||
if (!batch_x.empty()) {
|
||||
auto audio = Process(batch_x, batch_tones, sid, speed);
|
||||
ans.sample_rate = audio.sample_rate;
|
||||
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
||||
audio.samples.end());
|
||||
@@ -308,6 +343,12 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
if (meta_data.frontend == "characters") {
|
||||
frontend_ = std::make_unique<OfflineTtsCharacterFrontend>(
|
||||
config_.model.vits.tokens, meta_data);
|
||||
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty() &&
|
||||
meta_data.is_melo_tts) {
|
||||
frontend_ = std::make_unique<MeloTtsLexicon>(
|
||||
config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||
config_.model.vits.dict_dir, model_->GetMetaData(),
|
||||
config_.model.debug);
|
||||
} else if (meta_data.jieba && !config_.model.vits.dict_dir.empty()) {
|
||||
frontend_ = std::make_unique<JiebaLexicon>(
|
||||
config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||
@@ -344,6 +385,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
}
|
||||
|
||||
GeneratedAudio Process(const std::vector<std::vector<int64_t>> &tokens,
|
||||
const std::vector<std::vector<int64_t>> &tones,
|
||||
int32_t sid, float speed) const {
|
||||
int32_t num_tokens = 0;
|
||||
for (const auto &k : tokens) {
|
||||
@@ -356,6 +398,14 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
x.insert(x.end(), k.begin(), k.end());
|
||||
}
|
||||
|
||||
std::vector<int64_t> tone_list;
|
||||
if (!tones.empty()) {
|
||||
tone_list.reserve(num_tokens);
|
||||
for (const auto &k : tones) {
|
||||
tone_list.insert(tone_list.end(), k.begin(), k.end());
|
||||
}
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
@@ -363,7 +413,20 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
Ort::Value x_tensor = Ort::Value::CreateTensor(
|
||||
memory_info, x.data(), x.size(), x_shape.data(), x_shape.size());
|
||||
|
||||
Ort::Value audio = model_->Run(std::move(x_tensor), sid, speed);
|
||||
Ort::Value tones_tensor{nullptr};
|
||||
if (!tones.empty()) {
|
||||
tones_tensor = Ort::Value::CreateTensor(memory_info, tone_list.data(),
|
||||
tone_list.size(), x_shape.data(),
|
||||
x_shape.size());
|
||||
}
|
||||
|
||||
Ort::Value audio{nullptr};
|
||||
if (tones.empty()) {
|
||||
audio = model_->Run(std::move(x_tensor), sid, speed);
|
||||
} else {
|
||||
audio =
|
||||
model_->Run(std::move(x_tensor), std::move(tones_tensor), sid, speed);
|
||||
}
|
||||
|
||||
std::vector<int64_t> audio_shape =
|
||||
audio.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
Reference in New Issue
Block a user