Support heteronyms in Chinese TTS (#738)
This commit is contained in:
@@ -164,6 +164,7 @@ endif()
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
target_link_libraries(sherpa-onnx-core piper_phonemize)
|
||||
target_link_libraries(sherpa-onnx-core fstfar fst)
|
||||
endif()
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_CHECK)
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
#endif
|
||||
|
||||
#include <memory>
|
||||
#include <regex> // NOLINT
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
@@ -26,6 +25,55 @@
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Collapses heteronym phrase markers out of a word sequence.
//
// A phrase is delimited by the 3-token markers:  # $ |  words...  | $ #
// All words between an opening and closing marker are concatenated into a
// single output entry (stray "#", "$", "|" tokens inside the phrase are
// dropped). Words outside any phrase are copied through unchanged. An
// opening marker with no matching close discards the trailing phrase words,
// and a close with no open is simply skipped.
static std::vector<std::string> ProcessHeteronyms(
    const std::vector<std::string> &words) {
  const int32_t n = static_cast<int32_t>(words.size());

  // True if words[pos], words[pos+1], words[pos+2] spell the marker a b c.
  auto is_marker = [&words, n](int32_t pos, const char *a, const char *b,
                               const char *c) {
    return (pos + 2 < n) && words[pos] == a && words[pos + 1] == b &&
           words[pos + 2] == c;
  };

  std::vector<std::string> ans;
  ans.reserve(words.size());

  int32_t phrase_start = -1;  // index of first word inside an open phrase
  int32_t pos = 0;
  while (pos < n) {
    if (is_marker(pos, "#", "$", "|")) {
      // opening marker #$|; nested opens are ignored
      if (phrase_start == -1) {
        phrase_start = pos + 3;
      }
      pos += 3;
    } else if (is_marker(pos, "|", "$", "#")) {
      // closing marker |$#; merge the phrase if one is open
      if (phrase_start != -1) {
        std::string merged;
        for (int32_t k = phrase_start; k < pos; ++k) {
          const std::string &w = words[k];
          if (w != "|" && w != "$" && w != "#") {
            merged += w;
          }
        }
        ans.push_back(std::move(merged));
        phrase_start = -1;
      }
      pos += 3;
    } else {
      if (phrase_start == -1) {
        // not inside a phrase; keep the word as-is
        ans.push_back(words[pos]);
      }
      ++pos;
    }
  }

  return ans;
}
|
||||
|
||||
static void ToLowerCase(std::string *in_out) {
|
||||
std::transform(in_out->begin(), in_out->end(), in_out->begin(),
|
||||
[](unsigned char c) { return std::tolower(c); });
|
||||
@@ -148,36 +196,9 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
|
||||
const std::string &_text) const {
|
||||
std::string text(_text);
|
||||
ToLowerCase(&text);
|
||||
std::vector<std::string> words;
|
||||
if (pattern_) {
|
||||
// Handle polyphones
|
||||
size_t pos = 0;
|
||||
auto begin = std::sregex_iterator(text.begin(), text.end(), *pattern_);
|
||||
auto end = std::sregex_iterator();
|
||||
for (std::sregex_iterator i = begin; i != end; ++i) {
|
||||
std::smatch match = *i;
|
||||
if (pos < match.position()) {
|
||||
auto this_segment = text.substr(pos, match.position() - pos);
|
||||
auto this_segment_words = SplitUtf8(this_segment);
|
||||
words.insert(words.end(), this_segment_words.begin(),
|
||||
this_segment_words.end());
|
||||
pos = match.position() + match.length();
|
||||
} else if (pos == match.position()) {
|
||||
pos = match.position() + match.length();
|
||||
}
|
||||
|
||||
words.push_back(match.str());
|
||||
}
|
||||
|
||||
if (pos < text.size()) {
|
||||
auto this_segment = text.substr(pos, text.size() - pos);
|
||||
auto this_segment_words = SplitUtf8(this_segment);
|
||||
words.insert(words.end(), this_segment_words.begin(),
|
||||
this_segment_words.end());
|
||||
}
|
||||
} else {
|
||||
words = SplitUtf8(text);
|
||||
}
|
||||
std::vector<std::string> words = SplitUtf8(text);
|
||||
words = ProcessHeteronyms(words);
|
||||
|
||||
if (debug_) {
|
||||
fprintf(stderr, "Input text in string: %s\n", text.c_str());
|
||||
@@ -357,9 +378,6 @@ void Lexicon::InitLexicon(std::istream &is) {
|
||||
std::string line;
|
||||
std::string phone;
|
||||
|
||||
std::ostringstream os;
|
||||
std::string sep;
|
||||
|
||||
while (std::getline(is, line)) {
|
||||
std::istringstream iss(line);
|
||||
|
||||
@@ -381,18 +399,9 @@ void Lexicon::InitLexicon(std::istream &is) {
|
||||
if (ids.empty()) {
|
||||
continue;
|
||||
}
|
||||
if (language_ == Language::kChinese && word.size() > 3) {
|
||||
// this is not a single word;
|
||||
os << sep << word;
|
||||
sep = "|";
|
||||
}
|
||||
|
||||
word2ids_.insert({std::move(word), std::move(ids)});
|
||||
}
|
||||
|
||||
if (!sep.empty()) {
|
||||
pattern_ = std::make_unique<std::regex>(os.str());
|
||||
}
|
||||
}
|
||||
|
||||
void Lexicon::InitPunctuations(const std::string &punctuations) {
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <regex> // NOLINT
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
@@ -65,9 +64,6 @@ class Lexicon : public OfflineTtsFrontend {
|
||||
std::unordered_map<std::string, int32_t> token2id_;
|
||||
Language language_;
|
||||
bool debug_;
|
||||
|
||||
// for Chinese polyphones
|
||||
std::unique_ptr<std::regex> pattern_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -15,6 +15,9 @@
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "fst/extensions/far/far.h"
|
||||
#include "kaldifst/csrc/kaldi-fst-io.h"
|
||||
#include "kaldifst/csrc/text-normalizer.h"
|
||||
#include "sherpa-onnx/csrc/lexicon.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
@@ -46,6 +49,32 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
|
||||
tn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
|
||||
}
|
||||
}
|
||||
|
||||
if (!config.rule_fars.empty()) {
|
||||
if (config.model.debug) {
|
||||
SHERPA_ONNX_LOGE("Loading FST archives");
|
||||
}
|
||||
std::vector<std::string> files;
|
||||
SplitStringToVector(config.rule_fars, ",", false, &files);
|
||||
for (const auto &f : files) {
|
||||
if (config.model.debug) {
|
||||
SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
|
||||
}
|
||||
std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
|
||||
fst::FarReader<fst::StdArc>::Open(f));
|
||||
for (; !reader->Done(); reader->Next()) {
|
||||
std::unique_ptr<fst::StdConstFst> r(
|
||||
fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));
|
||||
|
||||
tn_list_.push_back(
|
||||
std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
|
||||
}
|
||||
}
|
||||
|
||||
if (config.model.debug) {
|
||||
SHERPA_ONNX_LOGE("FST archives loaded!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
|
||||
@@ -20,7 +20,14 @@ void OfflineTtsConfig::Register(ParseOptions *po) {
|
||||
"It not empty, it contains a list of rule FST filenames."
|
||||
"Multiple filenames are separated by a comma and they are "
|
||||
"applied from left to right. An example value: "
|
||||
"rule1.fst,rule2,fst,rule3.fst");
|
||||
"rule1.fst,rule2.fst,rule3.fst");
|
||||
|
||||
po->Register("tts-rule-fars", &rule_fars,
|
||||
"It not empty, it contains a list of rule FST archive filenames."
|
||||
"Multiple filenames are separated by a comma and they are "
|
||||
"applied from left to right. An example value: "
|
||||
"rule1.far,rule2.far,rule3.far. Note that an *.far can contain "
|
||||
"multiple *.fst files");
|
||||
|
||||
po->Register(
|
||||
"tts-max-num-sentences", &max_num_sentences,
|
||||
@@ -41,6 +48,17 @@ bool OfflineTtsConfig::Validate() const {
|
||||
}
|
||||
}
|
||||
|
||||
if (!rule_fars.empty()) {
|
||||
std::vector<std::string> files;
|
||||
SplitStringToVector(rule_fars, ",", false, &files);
|
||||
for (const auto &f : files) {
|
||||
if (!FileExists(f)) {
|
||||
SHERPA_ONNX_LOGE("Rule far %s does not exist. ", f.c_str());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return model.Validate();
|
||||
}
|
||||
|
||||
@@ -50,6 +68,7 @@ std::string OfflineTtsConfig::ToString() const {
|
||||
os << "OfflineTtsConfig(";
|
||||
os << "model=" << model.ToString() << ", ";
|
||||
os << "rule_fsts=\"" << rule_fsts << "\", ";
|
||||
os << "rule_fars=\"" << rule_fars << "\", ";
|
||||
os << "max_num_sentences=" << max_num_sentences << ")";
|
||||
|
||||
return os.str();
|
||||
|
||||
@@ -29,6 +29,9 @@ struct OfflineTtsConfig {
|
||||
// If there are multiple rules, they are applied from left to right.
|
||||
std::string rule_fsts;
|
||||
|
||||
// If there are multiple FST archives, they are applied from left to right.
|
||||
std::string rule_fars;
|
||||
|
||||
// Maximum number of sentences that we process at a time.
|
||||
// This is to avoid OOM for very long input text.
|
||||
// If you set it to -1, then we process all sentences in a single batch.
|
||||
@@ -36,9 +39,11 @@ struct OfflineTtsConfig {
|
||||
|
||||
OfflineTtsConfig() = default;
|
||||
OfflineTtsConfig(const OfflineTtsModelConfig &model,
|
||||
const std::string &rule_fsts, int32_t max_num_sentences)
|
||||
const std::string &rule_fsts, const std::string &rule_fars,
|
||||
int32_t max_num_sentences)
|
||||
: model(model),
|
||||
rule_fsts(rule_fsts),
|
||||
rule_fars(rule_fars),
|
||||
max_num_sentences(max_num_sentences) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
|
||||
Reference in New Issue
Block a user