Support heteronyms in Chinese TTS (#738)

This commit is contained in:
Fangjun Kuang
2024-04-08 11:01:30 +08:00
committed by GitHub
parent c1c0f5bafd
commit a5f8fbc83f
49 changed files with 308 additions and 143 deletions

View File

@@ -18,7 +18,6 @@
#endif
#include <memory>
#include <regex> // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
@@ -26,6 +25,55 @@
namespace sherpa_onnx {
static std::vector<std::string> ProcessHeteronyms(
const std::vector<std::string> &words) {
std::vector<std::string> ans;
ans.reserve(words.size());
int32_t num_words = static_cast<int32_t>(words.size());
int32_t i = 0;
int32_t prev = -1;
while (i < num_words) {
// start of a phrase #$|
if ((i + 2 < num_words) && words[i] == "#" && words[i + 1] == "$" &&
words[i + 2] == "|") {
if (prev == -1) {
prev = i + 3;
}
i = i + 3;
continue;
}
// end of a phrase |$#
if ((i + 2 < num_words) && words[i] == "|" && words[i + 1] == "$" &&
words[i + 2] == "#") {
if (prev != -1) {
std::ostringstream os;
for (int32_t k = prev; k < i; ++k) {
if (words[k] != "|" && words[k] != "$" && words[k] != "#") {
os << words[k];
}
}
ans.push_back(os.str());
prev = -1;
}
i += 3;
continue;
}
if (prev == -1) {
// not inside a phrase
ans.push_back(words[i]);
}
++i;
}
return ans;
}
static void ToLowerCase(std::string *in_out) {
std::transform(in_out->begin(), in_out->end(), in_out->begin(),
[](unsigned char c) { return std::tolower(c); });
@@ -148,36 +196,9 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);
std::vector<std::string> words;
if (pattern_) {
// Handle polyphones
size_t pos = 0;
auto begin = std::sregex_iterator(text.begin(), text.end(), *pattern_);
auto end = std::sregex_iterator();
for (std::sregex_iterator i = begin; i != end; ++i) {
std::smatch match = *i;
if (pos < match.position()) {
auto this_segment = text.substr(pos, match.position() - pos);
auto this_segment_words = SplitUtf8(this_segment);
words.insert(words.end(), this_segment_words.begin(),
this_segment_words.end());
pos = match.position() + match.length();
} else if (pos == match.position()) {
pos = match.position() + match.length();
}
words.push_back(match.str());
}
if (pos < text.size()) {
auto this_segment = text.substr(pos, text.size() - pos);
auto this_segment_words = SplitUtf8(this_segment);
words.insert(words.end(), this_segment_words.begin(),
this_segment_words.end());
}
} else {
words = SplitUtf8(text);
}
std::vector<std::string> words = SplitUtf8(text);
words = ProcessHeteronyms(words);
if (debug_) {
fprintf(stderr, "Input text in string: %s\n", text.c_str());
@@ -357,9 +378,6 @@ void Lexicon::InitLexicon(std::istream &is) {
std::string line;
std::string phone;
std::ostringstream os;
std::string sep;
while (std::getline(is, line)) {
std::istringstream iss(line);
@@ -381,18 +399,9 @@ void Lexicon::InitLexicon(std::istream &is) {
if (ids.empty()) {
continue;
}
if (language_ == Language::kChinese && word.size() > 3) {
// this is not a single word;
os << sep << word;
sep = "|";
}
word2ids_.insert({std::move(word), std::move(ids)});
}
if (!sep.empty()) {
pattern_ = std::make_unique<std::regex>(os.str());
}
}
void Lexicon::InitPunctuations(const std::string &punctuations) {