Use piper-phonemize to convert text to token IDs (#453)
This commit is contained in:
@@ -18,9 +18,11 @@
|
||||
#include "kaldifst/csrc/text-normalizer.h"
|
||||
#include "sherpa-onnx/csrc/lexicon.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-frontend.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-impl.h"
|
||||
#include "sherpa-onnx/csrc/offline-tts-vits-model.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/piper-phonemize-lexicon.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
@@ -29,10 +31,9 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
public:
|
||||
explicit OfflineTtsVitsImpl(const OfflineTtsConfig &config)
|
||||
: config_(config),
|
||||
model_(std::make_unique<OfflineTtsVitsModel>(config.model)),
|
||||
lexicon_(config.model.vits.lexicon, config.model.vits.tokens,
|
||||
model_->Punctuations(), model_->Language(), config.model.debug,
|
||||
model_->IsPiper()) {
|
||||
model_(std::make_unique<OfflineTtsVitsModel>(config.model)) {
|
||||
InitFrontend();
|
||||
|
||||
if (!config.rule_fsts.empty()) {
|
||||
std::vector<std::string> files;
|
||||
SplitStringToVector(config.rule_fsts, ",", false, &files);
|
||||
@@ -49,10 +50,9 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
#if __ANDROID_API__ >= 9
|
||||
OfflineTtsVitsImpl(AAssetManager *mgr, const OfflineTtsConfig &config)
|
||||
: config_(config),
|
||||
model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)),
|
||||
lexicon_(mgr, config.model.vits.lexicon, config.model.vits.tokens,
|
||||
model_->Punctuations(), model_->Language(), config.model.debug,
|
||||
model_->IsPiper()) {
|
||||
model_(std::make_unique<OfflineTtsVitsModel>(mgr, config.model)) {
|
||||
InitFrontend(mgr);
|
||||
|
||||
if (!config.rule_fsts.empty()) {
|
||||
std::vector<std::string> files;
|
||||
SplitStringToVector(config.rule_fsts, ",", false, &files);
|
||||
@@ -101,20 +101,119 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> x = lexicon_.ConvertTextToTokenIds(text);
|
||||
if (x.empty()) {
|
||||
std::vector<std::vector<int64_t>> x =
|
||||
frontend_->ConvertTextToTokenIds(text, model_->Voice());
|
||||
|
||||
if (x.empty() || (x.size() == 1 && x[0].empty())) {
|
||||
SHERPA_ONNX_LOGE("Failed to convert %s to token IDs", text.c_str());
|
||||
return {};
|
||||
}
|
||||
|
||||
if (model_->AddBlank()) {
|
||||
std::vector<int64_t> buffer(x.size() * 2 + 1);
|
||||
int32_t i = 1;
|
||||
for (auto k : x) {
|
||||
buffer[i] = k;
|
||||
i += 2;
|
||||
if (model_->AddBlank() && config_.model.vits.data_dir.empty()) {
|
||||
for (auto &k : x) {
|
||||
k = AddBlank(k);
|
||||
}
|
||||
x = std::move(buffer);
|
||||
}
|
||||
|
||||
int32_t x_size = static_cast<int32_t>(x.size());
|
||||
|
||||
if (config_.max_num_sentences <= 0 || x_size <= config_.max_num_sentences) {
|
||||
return Process(x, sid, speed);
|
||||
}
|
||||
|
||||
// the input text is too long, we process sentences within it in batches
|
||||
// to avoid OOM. Batch size is config_.max_num_sentences
|
||||
std::vector<std::vector<int64_t>> batch;
|
||||
int32_t batch_size = config_.max_num_sentences;
|
||||
batch.reserve(config_.max_num_sentences);
|
||||
int32_t num_batches = x_size / batch_size;
|
||||
|
||||
if (config_.model.debug) {
|
||||
SHERPA_ONNX_LOGE(
|
||||
"Text is too long. Split it into %d batches. batch size: %d. Number "
|
||||
"of sentences: %d",
|
||||
num_batches, batch_size, x_size);
|
||||
}
|
||||
|
||||
GeneratedAudio ans;
|
||||
|
||||
int32_t k = 0;
|
||||
|
||||
for (int32_t b = 0; b != num_batches; ++b) {
|
||||
batch.clear();
|
||||
for (int32_t i = 0; i != batch_size; ++i, ++k) {
|
||||
batch.push_back(std::move(x[k]));
|
||||
}
|
||||
|
||||
auto audio = Process(batch, sid, speed);
|
||||
ans.sample_rate = audio.sample_rate;
|
||||
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
||||
audio.samples.end());
|
||||
}
|
||||
|
||||
batch.clear();
|
||||
while (k < x.size()) {
|
||||
batch.push_back(std::move(x[k]));
|
||||
++k;
|
||||
}
|
||||
|
||||
if (!batch.empty()) {
|
||||
auto audio = Process(batch, sid, speed);
|
||||
ans.sample_rate = audio.sample_rate;
|
||||
ans.samples.insert(ans.samples.end(), audio.samples.begin(),
|
||||
audio.samples.end());
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
#if __ANDROID_API__ >= 9
// Creates frontend_, the object later used to turn text into token IDs, for
// the constructor overload that receives an Android asset manager.
//
// Selection rule:
//  - If the model reports IsPiper() and config_.model.vits.data_dir is
//    non-empty, a PiperPhonemizeLexicon is created from the tokens file and
//    data_dir.
//  - Otherwise a Lexicon is created from the lexicon and tokens files plus
//    the punctuation list, language and piper flag reported by the model.
//
// This overload is compiled only under the same __ANDROID_API__ condition as
// the constructor that calls it: AAssetManager comes from the Android NDK
// headers, so an unguarded definition would name an undeclared type on every
// other platform.
//
// @param mgr Asset manager forwarded unchanged to the frontend constructor.
void InitFrontend(AAssetManager *mgr) {
  if (model_->IsPiper() && !config_.model.vits.data_dir.empty()) {
    frontend_ = std::make_unique<PiperPhonemizeLexicon>(
        mgr, config_.model.vits.tokens, config_.model.vits.data_dir);
  } else {
    frontend_ = std::make_unique<Lexicon>(
        mgr, config_.model.vits.lexicon, config_.model.vits.tokens,
        model_->Punctuations(), model_->Language(), config_.model.debug,
        model_->IsPiper());
  }
}
#endif
|
||||
|
||||
void InitFrontend() {
|
||||
if (model_->IsPiper() && !config_.model.vits.data_dir.empty()) {
|
||||
frontend_ = std::make_unique<PiperPhonemizeLexicon>(
|
||||
config_.model.vits.tokens, config_.model.vits.data_dir);
|
||||
} else {
|
||||
frontend_ = std::make_unique<Lexicon>(
|
||||
config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||
model_->Punctuations(), model_->Language(), config_.model.debug,
|
||||
model_->IsPiper());
|
||||
}
|
||||
}
|
||||
|
||||
// Interleaves a blank token around every token ID of a sequence.
//
// For an input [a, b, c] the result is [0, a, 0, b, 0, c, 0]; an empty input
// yields [0]. The result always has 2 * x.size() + 1 elements.
//
// The function reads no instance state, so it is static. It writes through an
// iterator rather than a 32-bit signed index, so the position cannot overflow
// or be mixed with the vector's unsigned size type.
//
// @param x Token IDs of one sequence.
// @return  A new vector with the blank ID at every even position and the
//          elements of x, in order, at the odd positions.
static std::vector<int64_t> AddBlank(const std::vector<int64_t> &x) {
  // we assume the blank ID is 0; value-initialization fills the buffer with
  // it, so only the odd slots have to be written.
  std::vector<int64_t> buffer(x.size() * 2 + 1);

  auto slot = buffer.begin() + 1;
  for (int64_t id : x) {
    *slot = id;
    slot += 2;  // after the last token this lands exactly on buffer.end()
  }

  return buffer;
}
|
||||
|
||||
GeneratedAudio Process(const std::vector<std::vector<int64_t>> &tokens,
|
||||
int32_t sid, float speed) const {
|
||||
int32_t num_tokens = 0;
|
||||
for (const auto &k : tokens) {
|
||||
num_tokens += k.size();
|
||||
}
|
||||
|
||||
std::vector<int64_t> x;
|
||||
x.reserve(num_tokens);
|
||||
for (const auto &k : tokens) {
|
||||
x.insert(x.end(), k.begin(), k.end());
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
@@ -147,7 +246,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
OfflineTtsConfig config_;
|
||||
std::unique_ptr<OfflineTtsVitsModel> model_;
|
||||
std::vector<std::unique_ptr<kaldifst::TextNormalizer>> tn_list_;
|
||||
Lexicon lexicon_;
|
||||
std::unique_ptr<OfflineTtsFrontend> frontend_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user