Break text into sentences for tts. (#460)

This is for models that are not using piper-phonemize as their front-end.
This commit is contained in:
Fangjun Kuang
2023-12-03 11:50:25 +08:00
committed by GitHub
parent 99ff6a834c
commit 86b4be5260
3 changed files with 71 additions and 76 deletions

View File

@@ -88,8 +88,8 @@ static std::vector<int32_t> ConvertTokensToIds(
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
const std::string &punctuations, const std::string &language, const std::string &punctuations, const std::string &language,
bool debug /*= false*/, bool is_piper /*= false*/) bool debug /*= false*/)
: debug_(debug), is_piper_(is_piper) { : debug_(debug) {
InitLanguage(language); InitLanguage(language);
{ {
@@ -108,9 +108,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
#if __ANDROID_API__ >= 9 #if __ANDROID_API__ >= 9
Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon, Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
const std::string &tokens, const std::string &punctuations, const std::string &tokens, const std::string &punctuations,
const std::string &language, bool debug /*= false*/, const std::string &language, bool debug /*= false*/
bool is_piper /*= false*/) )
: debug_(debug), is_piper_(is_piper) { : debug_(debug) {
InitLanguage(language); InitLanguage(language);
{ {
@@ -132,16 +132,10 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds( std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
const std::string &text, const std::string & /*voice*/ /*= ""*/) const { const std::string &text, const std::string & /*voice*/ /*= ""*/) const {
switch (language_) { switch (language_) {
case Language::kEnglish:
return ConvertTextToTokenIdsEnglish(text);
case Language::kGerman:
return ConvertTextToTokenIdsGerman(text);
case Language::kSpanish:
return ConvertTextToTokenIdsSpanish(text);
case Language::kFrench:
return ConvertTextToTokenIdsFrench(text);
case Language::kChinese: case Language::kChinese:
return ConvertTextToTokenIdsChinese(text); return ConvertTextToTokenIdsChinese(text);
case Language::kNotChinese:
return ConvertTextToTokenIdsNotChinese(text);
default: default:
SHERPA_ONNX_LOGE("Unknown language: %d", static_cast<int32_t>(language_)); SHERPA_ONNX_LOGE("Unknown language: %d", static_cast<int32_t>(language_));
exit(-1); exit(-1);
@@ -197,7 +191,8 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
std::vector<int64_t> ans; std::vector<std::vector<int64_t>> ans;
std::vector<int64_t> this_sentence;
int32_t blank = -1; int32_t blank = -1;
if (token2id_.count(" ")) { if (token2id_.count(" ")) {
@@ -212,15 +207,32 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
} }
if (sil != -1) { if (sil != -1) {
ans.push_back(sil); this_sentence.push_back(sil);
} }
for (const auto &w : words) { for (const auto &w : words) {
if (punctuations_.count(w)) { if (w == "." || w == ";" || w == "!" || w == "?" || w == "-" || w == ":" ||
if (token2id_.count(w)) { w == "" || w == "" || w == "" || w == "" || w == "" ||
ans.push_back(token2id_.at(w)); w == "" ||
} else if (sil != -1) { // not sentence break
ans.push_back(sil); w == "," || w == "" || w == "" || w == "") {
if (punctuations_.count(w)) {
if (token2id_.count(w)) {
this_sentence.push_back(token2id_.at(w));
} else if (sil != -1) {
this_sentence.push_back(sil);
}
}
if (w != "," && w != "" && w != "" && w != "") {
if (eos != -1) {
this_sentence.push_back(eos);
}
ans.push_back(std::move(this_sentence));
if (sil != -1) {
this_sentence.push_back(sil);
}
} }
continue; continue;
} }
@@ -231,24 +243,26 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
} }
const auto &token_ids = word2ids_.at(w); const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), token_ids.begin(), token_ids.end()); this_sentence.insert(this_sentence.end(), token_ids.begin(),
token_ids.end());
if (blank != -1) { if (blank != -1) {
ans.push_back(blank); this_sentence.push_back(blank);
} }
} }
if (sil != -1) { if (sil != -1) {
ans.push_back(sil); this_sentence.push_back(sil);
} }
if (eos != -1) { if (eos != -1) {
ans.push_back(eos); this_sentence.push_back(eos);
} }
ans.push_back(std::move(this_sentence));
return {ans}; return ans;
} }
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsEnglish( std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
const std::string &_text) const { const std::string &_text) const {
std::string text(_text); std::string text(_text);
ToLowerCase(&text); ToLowerCase(&text);
@@ -271,14 +285,22 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsEnglish(
int32_t blank = token2id_.at(" "); int32_t blank = token2id_.at(" ");
std::vector<int64_t> ans; std::vector<std::vector<int64_t>> ans;
if (is_piper_ && token2id_.count("^")) { std::vector<int64_t> this_sentence;
ans.push_back(token2id_.at("^")); // sos
}
for (const auto &w : words) { for (const auto &w : words) {
if (punctuations_.count(w)) { if (w == "." || w == ";" || w == "!" || w == "?" || w == "-" || w == ":" ||
ans.push_back(token2id_.at(w)); // not sentence break
w == ",") {
if (punctuations_.count(w)) {
this_sentence.push_back(token2id_.at(w));
}
if (w != ",") {
this_sentence.push_back(blank);
ans.push_back(std::move(this_sentence));
}
continue; continue;
} }
@@ -288,20 +310,21 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsEnglish(
} }
const auto &token_ids = word2ids_.at(w); const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), token_ids.begin(), token_ids.end()); this_sentence.insert(this_sentence.end(), token_ids.begin(),
ans.push_back(blank); token_ids.end());
this_sentence.push_back(blank);
} }
if (!ans.empty()) { if (!this_sentence.empty()) {
// remove the last blank // remove the last blank
ans.resize(ans.size() - 1); this_sentence.resize(this_sentence.size() - 1);
} }
if (is_piper_ && token2id_.count("$")) { if (!this_sentence.empty()) {
ans.push_back(token2id_.at("$")); // eos ans.push_back(std::move(this_sentence));
} }
return {ans}; return ans;
} }
void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); } void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); }
@@ -309,16 +332,10 @@ void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); }
void Lexicon::InitLanguage(const std::string &_lang) { void Lexicon::InitLanguage(const std::string &_lang) {
std::string lang(_lang); std::string lang(_lang);
ToLowerCase(&lang); ToLowerCase(&lang);
if (lang == "english") { if (lang == "chinese") {
language_ = Language::kEnglish;
} else if (lang == "german") {
language_ = Language::kGerman;
} else if (lang == "spanish") {
language_ = Language::kSpanish;
} else if (lang == "french") {
language_ = Language::kFrench;
} else if (lang == "chinese") {
language_ = Language::kChinese; language_ = Language::kChinese;
} else if (!lang.empty()) {
language_ = Language::kNotChinese;
} else { } else {
SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str()); SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str());
exit(-1); exit(-1);

View File

@@ -29,35 +29,19 @@ class Lexicon : public OfflineTtsFrontend {
// Note: for models from piper, we won't use this class. // Note: for models from piper, we won't use this class.
Lexicon(const std::string &lexicon, const std::string &tokens, Lexicon(const std::string &lexicon, const std::string &tokens,
const std::string &punctuations, const std::string &language, const std::string &punctuations, const std::string &language,
bool debug = false, bool is_piper = false); bool debug = false);
#if __ANDROID_API__ >= 9 #if __ANDROID_API__ >= 9
Lexicon(AAssetManager *mgr, const std::string &lexicon, Lexicon(AAssetManager *mgr, const std::string &lexicon,
const std::string &tokens, const std::string &punctuations, const std::string &tokens, const std::string &punctuations,
const std::string &language, bool debug = false, const std::string &language, bool debug = false);
bool is_piper = false);
#endif #endif
std::vector<std::vector<int64_t>> ConvertTextToTokenIds( std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
const std::string &text, const std::string &voice = "") const override; const std::string &text, const std::string &voice = "") const override;
private: private:
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsGerman( std::vector<std::vector<int64_t>> ConvertTextToTokenIdsNotChinese(
const std::string &text) const {
return ConvertTextToTokenIdsEnglish(text);
}
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsSpanish(
const std::string &text) const {
return ConvertTextToTokenIdsEnglish(text);
}
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsFrench(
const std::string &text) const {
return ConvertTextToTokenIdsEnglish(text);
}
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsEnglish(
const std::string &text) const; const std::string &text) const;
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsChinese( std::vector<std::vector<int64_t>> ConvertTextToTokenIdsChinese(
@@ -70,10 +54,7 @@ class Lexicon : public OfflineTtsFrontend {
private: private:
enum class Language { enum class Language {
kEnglish, kNotChinese,
kGerman,
kSpanish,
kFrench,
kChinese, kChinese,
kUnknown, kUnknown,
}; };
@@ -84,7 +65,6 @@ class Lexicon : public OfflineTtsFrontend {
std::unordered_map<std::string, int32_t> token2id_; std::unordered_map<std::string, int32_t> token2id_;
Language language_; Language language_;
bool debug_; bool debug_;
bool is_piper_;
// for Chinese polyphones // for Chinese polyphones
std::unique_ptr<std::regex> pattern_; std::unique_ptr<std::regex> pattern_;

View File

@@ -195,8 +195,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
} else { } else {
frontend_ = std::make_unique<Lexicon>( frontend_ = std::make_unique<Lexicon>(
mgr, config_.model.vits.lexicon, config_.model.vits.tokens, mgr, config_.model.vits.lexicon, config_.model.vits.tokens,
model_->Punctuations(), model_->Language(), config_.model.debug, model_->Punctuations(), model_->Language(), config_.model.debug);
model_->IsPiper());
} }
} }
#endif #endif
@@ -208,8 +207,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
} else { } else {
frontend_ = std::make_unique<Lexicon>( frontend_ = std::make_unique<Lexicon>(
config_.model.vits.lexicon, config_.model.vits.tokens, config_.model.vits.lexicon, config_.model.vits.tokens,
model_->Punctuations(), model_->Language(), config_.model.debug, model_->Punctuations(), model_->Language(), config_.model.debug);
model_->IsPiper());
} }
} }