Break text into sentences for tts. (#460)
This is for models that are not using piper-phonemize as their front-end.
This commit is contained in:
@@ -88,8 +88,8 @@ static std::vector<int32_t> ConvertTokensToIds(
|
|||||||
|
|
||||||
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
|
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
|
||||||
const std::string &punctuations, const std::string &language,
|
const std::string &punctuations, const std::string &language,
|
||||||
bool debug /*= false*/, bool is_piper /*= false*/)
|
bool debug /*= false*/)
|
||||||
: debug_(debug), is_piper_(is_piper) {
|
: debug_(debug) {
|
||||||
InitLanguage(language);
|
InitLanguage(language);
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -108,9 +108,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
|
|||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
|
Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
|
||||||
const std::string &tokens, const std::string &punctuations,
|
const std::string &tokens, const std::string &punctuations,
|
||||||
const std::string &language, bool debug /*= false*/,
|
const std::string &language, bool debug /*= false*/
|
||||||
bool is_piper /*= false*/)
|
)
|
||||||
: debug_(debug), is_piper_(is_piper) {
|
: debug_(debug) {
|
||||||
InitLanguage(language);
|
InitLanguage(language);
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -132,16 +132,10 @@ Lexicon::Lexicon(AAssetManager *mgr, const std::string &lexicon,
|
|||||||
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
|
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIds(
|
||||||
const std::string &text, const std::string & /*voice*/ /*= ""*/) const {
|
const std::string &text, const std::string & /*voice*/ /*= ""*/) const {
|
||||||
switch (language_) {
|
switch (language_) {
|
||||||
case Language::kEnglish:
|
|
||||||
return ConvertTextToTokenIdsEnglish(text);
|
|
||||||
case Language::kGerman:
|
|
||||||
return ConvertTextToTokenIdsGerman(text);
|
|
||||||
case Language::kSpanish:
|
|
||||||
return ConvertTextToTokenIdsSpanish(text);
|
|
||||||
case Language::kFrench:
|
|
||||||
return ConvertTextToTokenIdsFrench(text);
|
|
||||||
case Language::kChinese:
|
case Language::kChinese:
|
||||||
return ConvertTextToTokenIdsChinese(text);
|
return ConvertTextToTokenIdsChinese(text);
|
||||||
|
case Language::kNotChinese:
|
||||||
|
return ConvertTextToTokenIdsNotChinese(text);
|
||||||
default:
|
default:
|
||||||
SHERPA_ONNX_LOGE("Unknown language: %d", static_cast<int32_t>(language_));
|
SHERPA_ONNX_LOGE("Unknown language: %d", static_cast<int32_t>(language_));
|
||||||
exit(-1);
|
exit(-1);
|
||||||
@@ -197,7 +191,8 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
|
|||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int64_t> ans;
|
std::vector<std::vector<int64_t>> ans;
|
||||||
|
std::vector<int64_t> this_sentence;
|
||||||
|
|
||||||
int32_t blank = -1;
|
int32_t blank = -1;
|
||||||
if (token2id_.count(" ")) {
|
if (token2id_.count(" ")) {
|
||||||
@@ -212,15 +207,32 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (sil != -1) {
|
if (sil != -1) {
|
||||||
ans.push_back(sil);
|
this_sentence.push_back(sil);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto &w : words) {
|
for (const auto &w : words) {
|
||||||
if (punctuations_.count(w)) {
|
if (w == "." || w == ";" || w == "!" || w == "?" || w == "-" || w == ":" ||
|
||||||
if (token2id_.count(w)) {
|
w == "。" || w == ";" || w == "!" || w == "?" || w == ":" ||
|
||||||
ans.push_back(token2id_.at(w));
|
w == "”" ||
|
||||||
} else if (sil != -1) {
|
// not sentence break
|
||||||
ans.push_back(sil);
|
w == "," || w == "“" || w == "," || w == "、") {
|
||||||
|
if (punctuations_.count(w)) {
|
||||||
|
if (token2id_.count(w)) {
|
||||||
|
this_sentence.push_back(token2id_.at(w));
|
||||||
|
} else if (sil != -1) {
|
||||||
|
this_sentence.push_back(sil);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (w != "," && w != "“" && w != "," && w != "、") {
|
||||||
|
if (eos != -1) {
|
||||||
|
this_sentence.push_back(eos);
|
||||||
|
}
|
||||||
|
ans.push_back(std::move(this_sentence));
|
||||||
|
|
||||||
|
if (sil != -1) {
|
||||||
|
this_sentence.push_back(sil);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -231,24 +243,26 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const auto &token_ids = word2ids_.at(w);
|
const auto &token_ids = word2ids_.at(w);
|
||||||
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
this_sentence.insert(this_sentence.end(), token_ids.begin(),
|
||||||
|
token_ids.end());
|
||||||
if (blank != -1) {
|
if (blank != -1) {
|
||||||
ans.push_back(blank);
|
this_sentence.push_back(blank);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sil != -1) {
|
if (sil != -1) {
|
||||||
ans.push_back(sil);
|
this_sentence.push_back(sil);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (eos != -1) {
|
if (eos != -1) {
|
||||||
ans.push_back(eos);
|
this_sentence.push_back(eos);
|
||||||
}
|
}
|
||||||
|
ans.push_back(std::move(this_sentence));
|
||||||
|
|
||||||
return {ans};
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsEnglish(
|
std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsNotChinese(
|
||||||
const std::string &_text) const {
|
const std::string &_text) const {
|
||||||
std::string text(_text);
|
std::string text(_text);
|
||||||
ToLowerCase(&text);
|
ToLowerCase(&text);
|
||||||
@@ -271,14 +285,22 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsEnglish(
|
|||||||
|
|
||||||
int32_t blank = token2id_.at(" ");
|
int32_t blank = token2id_.at(" ");
|
||||||
|
|
||||||
std::vector<int64_t> ans;
|
std::vector<std::vector<int64_t>> ans;
|
||||||
if (is_piper_ && token2id_.count("^")) {
|
std::vector<int64_t> this_sentence;
|
||||||
ans.push_back(token2id_.at("^")); // sos
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const auto &w : words) {
|
for (const auto &w : words) {
|
||||||
if (punctuations_.count(w)) {
|
if (w == "." || w == ";" || w == "!" || w == "?" || w == "-" || w == ":" ||
|
||||||
ans.push_back(token2id_.at(w));
|
// not sentence break
|
||||||
|
w == ",") {
|
||||||
|
if (punctuations_.count(w)) {
|
||||||
|
this_sentence.push_back(token2id_.at(w));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (w != ",") {
|
||||||
|
this_sentence.push_back(blank);
|
||||||
|
ans.push_back(std::move(this_sentence));
|
||||||
|
}
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -288,20 +310,21 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsEnglish(
|
|||||||
}
|
}
|
||||||
|
|
||||||
const auto &token_ids = word2ids_.at(w);
|
const auto &token_ids = word2ids_.at(w);
|
||||||
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
this_sentence.insert(this_sentence.end(), token_ids.begin(),
|
||||||
ans.push_back(blank);
|
token_ids.end());
|
||||||
|
this_sentence.push_back(blank);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!ans.empty()) {
|
if (!this_sentence.empty()) {
|
||||||
// remove the last blank
|
// remove the last blank
|
||||||
ans.resize(ans.size() - 1);
|
this_sentence.resize(this_sentence.size() - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (is_piper_ && token2id_.count("$")) {
|
if (!this_sentence.empty()) {
|
||||||
ans.push_back(token2id_.at("$")); // eos
|
ans.push_back(std::move(this_sentence));
|
||||||
}
|
}
|
||||||
|
|
||||||
return {ans};
|
return ans;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); }
|
void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); }
|
||||||
@@ -309,16 +332,10 @@ void Lexicon::InitTokens(std::istream &is) { token2id_ = ReadTokens(is); }
|
|||||||
void Lexicon::InitLanguage(const std::string &_lang) {
|
void Lexicon::InitLanguage(const std::string &_lang) {
|
||||||
std::string lang(_lang);
|
std::string lang(_lang);
|
||||||
ToLowerCase(&lang);
|
ToLowerCase(&lang);
|
||||||
if (lang == "english") {
|
if (lang == "chinese") {
|
||||||
language_ = Language::kEnglish;
|
|
||||||
} else if (lang == "german") {
|
|
||||||
language_ = Language::kGerman;
|
|
||||||
} else if (lang == "spanish") {
|
|
||||||
language_ = Language::kSpanish;
|
|
||||||
} else if (lang == "french") {
|
|
||||||
language_ = Language::kFrench;
|
|
||||||
} else if (lang == "chinese") {
|
|
||||||
language_ = Language::kChinese;
|
language_ = Language::kChinese;
|
||||||
|
} else if (!lang.empty()) {
|
||||||
|
language_ = Language::kNotChinese;
|
||||||
} else {
|
} else {
|
||||||
SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str());
|
SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str());
|
||||||
exit(-1);
|
exit(-1);
|
||||||
|
|||||||
@@ -29,35 +29,19 @@ class Lexicon : public OfflineTtsFrontend {
|
|||||||
// Note: for models from piper, we won't use this class.
|
// Note: for models from piper, we won't use this class.
|
||||||
Lexicon(const std::string &lexicon, const std::string &tokens,
|
Lexicon(const std::string &lexicon, const std::string &tokens,
|
||||||
const std::string &punctuations, const std::string &language,
|
const std::string &punctuations, const std::string &language,
|
||||||
bool debug = false, bool is_piper = false);
|
bool debug = false);
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
Lexicon(AAssetManager *mgr, const std::string &lexicon,
|
Lexicon(AAssetManager *mgr, const std::string &lexicon,
|
||||||
const std::string &tokens, const std::string &punctuations,
|
const std::string &tokens, const std::string &punctuations,
|
||||||
const std::string &language, bool debug = false,
|
const std::string &language, bool debug = false);
|
||||||
bool is_piper = false);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
std::vector<std::vector<int64_t>> ConvertTextToTokenIds(
|
||||||
const std::string &text, const std::string &voice = "") const override;
|
const std::string &text, const std::string &voice = "") const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsGerman(
|
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsNotChinese(
|
||||||
const std::string &text) const {
|
|
||||||
return ConvertTextToTokenIdsEnglish(text);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsSpanish(
|
|
||||||
const std::string &text) const {
|
|
||||||
return ConvertTextToTokenIdsEnglish(text);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsFrench(
|
|
||||||
const std::string &text) const {
|
|
||||||
return ConvertTextToTokenIdsEnglish(text);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsEnglish(
|
|
||||||
const std::string &text) const;
|
const std::string &text) const;
|
||||||
|
|
||||||
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsChinese(
|
std::vector<std::vector<int64_t>> ConvertTextToTokenIdsChinese(
|
||||||
@@ -70,10 +54,7 @@ class Lexicon : public OfflineTtsFrontend {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
enum class Language {
|
enum class Language {
|
||||||
kEnglish,
|
kNotChinese,
|
||||||
kGerman,
|
|
||||||
kSpanish,
|
|
||||||
kFrench,
|
|
||||||
kChinese,
|
kChinese,
|
||||||
kUnknown,
|
kUnknown,
|
||||||
};
|
};
|
||||||
@@ -84,7 +65,6 @@ class Lexicon : public OfflineTtsFrontend {
|
|||||||
std::unordered_map<std::string, int32_t> token2id_;
|
std::unordered_map<std::string, int32_t> token2id_;
|
||||||
Language language_;
|
Language language_;
|
||||||
bool debug_;
|
bool debug_;
|
||||||
bool is_piper_;
|
|
||||||
|
|
||||||
// for Chinese polyphones
|
// for Chinese polyphones
|
||||||
std::unique_ptr<std::regex> pattern_;
|
std::unique_ptr<std::regex> pattern_;
|
||||||
|
|||||||
@@ -195,8 +195,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
|||||||
} else {
|
} else {
|
||||||
frontend_ = std::make_unique<Lexicon>(
|
frontend_ = std::make_unique<Lexicon>(
|
||||||
mgr, config_.model.vits.lexicon, config_.model.vits.tokens,
|
mgr, config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||||
model_->Punctuations(), model_->Language(), config_.model.debug,
|
model_->Punctuations(), model_->Language(), config_.model.debug);
|
||||||
model_->IsPiper());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@@ -208,8 +207,7 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
|||||||
} else {
|
} else {
|
||||||
frontend_ = std::make_unique<Lexicon>(
|
frontend_ = std::make_unique<Lexicon>(
|
||||||
config_.model.vits.lexicon, config_.model.vits.tokens,
|
config_.model.vits.lexicon, config_.model.vits.tokens,
|
||||||
model_->Punctuations(), model_->Language(), config_.model.debug,
|
model_->Punctuations(), model_->Language(), config_.model.debug);
|
||||||
model_->IsPiper());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user