Support heteronyms in Chinese TTS (#738)

This commit is contained in:
Fangjun Kuang
2024-04-08 11:01:30 +08:00
committed by GitHub
parent c1c0f5bafd
commit a5f8fbc83f
49 changed files with 308 additions and 143 deletions

View File

@@ -818,6 +818,7 @@ SherpaOnnxOfflineTts *SherpaOnnxCreateOfflineTts(
tts_config.model.debug = config->model.debug;
tts_config.model.provider = SHERPA_ONNX_OR(config->model.provider, "cpu");
tts_config.rule_fsts = SHERPA_ONNX_OR(config->rule_fsts, "");
tts_config.rule_fars = SHERPA_ONNX_OR(config->rule_fars, "");
tts_config.max_num_sentences = SHERPA_ONNX_OR(config->max_num_sentences, 2);
if (tts_config.model.debug) {

View File

@@ -783,6 +783,7 @@ SHERPA_ONNX_API typedef struct SherpaOnnxOfflineTtsConfig {
SherpaOnnxOfflineTtsModelConfig model;
const char *rule_fsts;
int32_t max_num_sentences;
const char *rule_fars;
} SherpaOnnxOfflineTtsConfig;
SHERPA_ONNX_API typedef struct SherpaOnnxGeneratedAudio {

View File

@@ -164,6 +164,7 @@ endif()
# When TTS support is enabled, link these additional libraries into
# sherpa-onnx-core.
if(SHERPA_ONNX_ENABLE_TTS)
target_link_libraries(sherpa-onnx-core piper_phonemize)
# NOTE(review): fstfar/fst are presumably the OpenFst FAR extension and the
# OpenFst core library, needed for reading rule FST archives -- confirm
# against the project's dependency setup.
target_link_libraries(sherpa-onnx-core fstfar fst)
endif()
if(SHERPA_ONNX_ENABLE_CHECK)

View File

@@ -18,7 +18,6 @@
#endif
#include <memory>
#include <regex> // NOLINT
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
@@ -26,6 +25,55 @@
namespace sherpa_onnx {
// Collapses every marked phrase in `words` into a single word.
//
// A phrase starts with the three consecutive words "#", "$", "|" and ends
// with the three consecutive words "|", "$", "#". All words between a start
// marker and the next end marker are concatenated into one entry of the
// result; the markers themselves and any stray "|", "$" or "#" inside the
// phrase are discarded. A start marker seen while a phrase is already open
// does not begin a new phrase. Words outside of a phrase are copied through
// unchanged, and an end marker without a preceding start marker is dropped.
//
// If the input ends while a phrase is still open (its end marker is
// missing), the pending words are emitted one by one, with the marker
// characters filtered out, instead of being silently lost.
//
// @param words The input sequence of words.
// @return The words with each marked phrase merged into a single entry.
static std::vector<std::string> ProcessHeteronyms(
    const std::vector<std::string> &words) {
  std::vector<std::string> ans;
  ans.reserve(words.size());

  const int32_t num_words = static_cast<int32_t>(words.size());

  // True if words[pos], words[pos + 1], words[pos + 2] exist and equal a, b, c.
  auto is_marker = [&words, num_words](int32_t pos, const char *a,
                                       const char *b, const char *c) {
    return (pos + 2 < num_words) && words[pos] == a && words[pos + 1] == b &&
           words[pos + 2] == c;
  };

  // True for the single characters that make up the phrase markers.
  auto is_marker_char = [](const std::string &w) {
    return w == "|" || w == "$" || w == "#";
  };

  // Index of the first word of the currently open phrase; -1 if none is open.
  int32_t phrase_begin = -1;

  int32_t i = 0;
  while (i < num_words) {
    // start of a phrase #$|
    if (is_marker(i, "#", "$", "|")) {
      if (phrase_begin == -1) {
        phrase_begin = i + 3;
      }
      i += 3;
      continue;
    }

    // end of a phrase |$#
    if (is_marker(i, "|", "$", "#")) {
      if (phrase_begin != -1) {
        // Build the merged word directly inside the result vector.
        ans.emplace_back();
        std::string &phrase = ans.back();
        for (int32_t k = phrase_begin; k < i; ++k) {
          if (!is_marker_char(words[k])) {
            phrase += words[k];
          }
        }
        phrase_begin = -1;
      }
      i += 3;
      continue;
    }

    if (phrase_begin == -1) {
      // not inside a phrase
      ans.push_back(words[i]);
    }
    ++i;
  }

  // The phrase was never closed: keep its words rather than dropping the
  // rest of the input.
  if (phrase_begin != -1) {
    for (int32_t k = phrase_begin; k < num_words; ++k) {
      if (!is_marker_char(words[k])) {
        ans.push_back(words[k]);
      }
    }
  }

  return ans;
}
static void ToLowerCase(std::string *in_out) {
std::transform(in_out->begin(), in_out->end(), in_out->begin(),
[](unsigned char c) { return std::tolower(c); });
@@ -148,36 +196,9 @@ std::vector<std::vector<int64_t>> Lexicon::ConvertTextToTokenIdsChinese(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);
std::vector<std::string> words;
if (pattern_) {
// Handle polyphones
size_t pos = 0;
auto begin = std::sregex_iterator(text.begin(), text.end(), *pattern_);
auto end = std::sregex_iterator();
for (std::sregex_iterator i = begin; i != end; ++i) {
std::smatch match = *i;
if (pos < match.position()) {
auto this_segment = text.substr(pos, match.position() - pos);
auto this_segment_words = SplitUtf8(this_segment);
words.insert(words.end(), this_segment_words.begin(),
this_segment_words.end());
pos = match.position() + match.length();
} else if (pos == match.position()) {
pos = match.position() + match.length();
}
words.push_back(match.str());
}
if (pos < text.size()) {
auto this_segment = text.substr(pos, text.size() - pos);
auto this_segment_words = SplitUtf8(this_segment);
words.insert(words.end(), this_segment_words.begin(),
this_segment_words.end());
}
} else {
words = SplitUtf8(text);
}
std::vector<std::string> words = SplitUtf8(text);
words = ProcessHeteronyms(words);
if (debug_) {
fprintf(stderr, "Input text in string: %s\n", text.c_str());
@@ -357,9 +378,6 @@ void Lexicon::InitLexicon(std::istream &is) {
std::string line;
std::string phone;
std::ostringstream os;
std::string sep;
while (std::getline(is, line)) {
std::istringstream iss(line);
@@ -381,18 +399,9 @@ void Lexicon::InitLexicon(std::istream &is) {
if (ids.empty()) {
continue;
}
if (language_ == Language::kChinese && word.size() > 3) {
// this is not a single word;
os << sep << word;
sep = "|";
}
word2ids_.insert({std::move(word), std::move(ids)});
}
if (!sep.empty()) {
pattern_ = std::make_unique<std::regex>(os.str());
}
}
void Lexicon::InitPunctuations(const std::string &punctuations) {

View File

@@ -7,7 +7,6 @@
#include <cstdint>
#include <memory>
#include <regex> // NOLINT
#include <string>
#include <unordered_map>
#include <unordered_set>
@@ -65,9 +64,6 @@ class Lexicon : public OfflineTtsFrontend {
std::unordered_map<std::string, int32_t> token2id_;
Language language_;
bool debug_;
// for Chinese polyphones
std::unique_ptr<std::regex> pattern_;
};
} // namespace sherpa_onnx

View File

@@ -15,6 +15,9 @@
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "fst/extensions/far/far.h"
#include "kaldifst/csrc/kaldi-fst-io.h"
#include "kaldifst/csrc/text-normalizer.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
@@ -46,6 +49,32 @@ class OfflineTtsVitsImpl : public OfflineTtsImpl {
tn_list_.push_back(std::make_unique<kaldifst::TextNormalizer>(f));
}
}
// Load text-normalization rules from the comma-separated list of FST
// archives in config.rule_fars. Every FST found in an archive becomes one
// TextNormalizer appended to tn_list_, in the order the archives are listed
// and the FSTs are stored.
if (!config.rule_fars.empty()) {
  if (config.model.debug) {
    SHERPA_ONNX_LOGE("Loading FST archives");
  }

  std::vector<std::string> files;
  SplitStringToVector(config.rule_fars, ",", false, &files);

  for (const auto &f : files) {
    if (config.model.debug) {
      SHERPA_ONNX_LOGE("rule far: %s", f.c_str());
    }

    std::unique_ptr<fst::FarReader<fst::StdArc>> reader(
        fst::FarReader<fst::StdArc>::Open(f));

    // Open() yields a null pointer when the archive cannot be read (for
    // example, a file that exists but is not a valid FAR). Report it and
    // move on instead of dereferencing the null reader below.
    if (!reader) {
      SHERPA_ONNX_LOGE("Failed to open rule far: %s", f.c_str());
      continue;
    }

    for (; !reader->Done(); reader->Next()) {
      // Copy the FST out of the reader so that it outlives the reader.
      std::unique_ptr<fst::StdConstFst> r(
          fst::CastOrConvertToConstFst(reader->GetFst()->Copy()));

      tn_list_.push_back(
          std::make_unique<kaldifst::TextNormalizer>(std::move(r)));
    }
  }

  if (config.model.debug) {
    SHERPA_ONNX_LOGE("FST archives loaded!");
  }
}
}
#if __ANDROID_API__ >= 9

View File

@@ -20,7 +20,14 @@ void OfflineTtsConfig::Register(ParseOptions *po) {
"It not empty, it contains a list of rule FST filenames."
"Multiple filenames are separated by a comma and they are "
"applied from left to right. An example value: "
"rule1.fst,rule2,fst,rule3.fst");
"rule1.fst,rule2.fst,rule3.fst");
po->Register("tts-rule-fars", &rule_fars,
"It not empty, it contains a list of rule FST archive filenames."
"Multiple filenames are separated by a comma and they are "
"applied from left to right. An example value: "
"rule1.far,rule2.far,rule3.far. Note that an *.far can contain "
"multiple *.fst files");
po->Register(
"tts-max-num-sentences", &max_num_sentences,
@@ -41,6 +48,17 @@ bool OfflineTtsConfig::Validate() const {
}
}
if (!rule_fars.empty()) {
std::vector<std::string> files;
SplitStringToVector(rule_fars, ",", false, &files);
for (const auto &f : files) {
if (!FileExists(f)) {
SHERPA_ONNX_LOGE("Rule far %s does not exist. ", f.c_str());
return false;
}
}
}
return model.Validate();
}
@@ -50,6 +68,7 @@ std::string OfflineTtsConfig::ToString() const {
os << "OfflineTtsConfig(";
os << "model=" << model.ToString() << ", ";
os << "rule_fsts=\"" << rule_fsts << "\", ";
os << "rule_fars=\"" << rule_fars << "\", ";
os << "max_num_sentences=" << max_num_sentences << ")";
return os.str();

View File

@@ -29,6 +29,9 @@ struct OfflineTtsConfig {
// If there are multiple rules, they are applied from left to right.
std::string rule_fsts;
// If there are multiple FST archives, they are applied from left to right.
std::string rule_fars;
// Maximum number of sentences that we process at a time.
// This is to avoid OOM for very long input text.
// If you set it to -1, then we process all sentences in a single batch.
@@ -36,9 +39,11 @@ struct OfflineTtsConfig {
OfflineTtsConfig() = default;
OfflineTtsConfig(const OfflineTtsModelConfig &model,
const std::string &rule_fsts, int32_t max_num_sentences)
const std::string &rule_fsts, const std::string &rule_fars,
int32_t max_num_sentences)
: model(model),
rule_fsts(rule_fsts),
rule_fars(rule_fars),
max_num_sentences(max_num_sentences) {}
void Register(ParseOptions *po);

View File

@@ -878,6 +878,13 @@ static OfflineTtsConfig GetOfflineTtsConfig(JNIEnv *env, jobject config) {
ans.rule_fsts = p;
env->ReleaseStringUTFChars(s, p);
// for ruleFars
fid = env->GetFieldID(cls, "ruleFars", "Ljava/lang/String;");
s = (jstring)env->GetObjectField(config, fid);
p = env->GetStringUTFChars(s, nullptr);
ans.rule_fars = p;
env->ReleaseStringUTFChars(s, p);
fid = env->GetFieldID(cls, "maxNumSentences", "I");
ans.max_num_sentences = env->GetIntField(config, fid);

View File

@@ -32,11 +32,12 @@ static void PybindOfflineTtsConfig(py::module *m) {
py::class_<PyClass>(*m, "OfflineTtsConfig")
.def(py::init<>())
.def(py::init<const OfflineTtsModelConfig &, const std::string &,
int32_t>(),
const std::string &, int32_t>(),
py::arg("model"), py::arg("rule_fsts") = "",
py::arg("max_num_sentences") = 2)
py::arg("rule_fars") = "", py::arg("max_num_sentences") = 2)
.def_readwrite("model", &PyClass::model)
.def_readwrite("rule_fsts", &PyClass::rule_fsts)
.def_readwrite("rule_fars", &PyClass::rule_fars)
.def_readwrite("max_num_sentences", &PyClass::max_num_sentences)
.def("validate", &PyClass::Validate)
.def("__str__", &PyClass::ToString);