This repository has been archived on 2025-08-26. You can view files and clone it, but cannot push or open issues or pull requests.
Files
enginex_bi_series-sherpa-onnx/sherpa-onnx/csrc/melo-tts-lexicon.cc
2024-10-24 11:15:08 +08:00

277 lines
7.8 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// sherpa-onnx/csrc/melo-tts-lexicon.cc
//
// Copyright (c) 2022-2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/melo-tts-lexicon.h"

#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <memory>
#include <regex>  // NOLINT
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "cppjieba/Jieba.hpp"

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/lexicon.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/text-utils.h"
namespace sherpa_onnx {
class MeloTtsLexicon::Impl {
 public:
  /// @param lexicon  Path to lexicon.txt. Each line has the format
  ///                 "word phone1 ... phoneN tone1 ... toneN".
  /// @param tokens   Path to tokens.txt mapping phone symbols to integer IDs.
  /// @param dict_dir Directory containing the jieba dictionary files.
  /// @param meta_data Model meta data; copied into this object.
  /// @param debug    If true, log intermediate word-segmentation results.
  Impl(const std::string &lexicon, const std::string &tokens,
       const std::string &dict_dir,
       const OfflineTtsVitsModelMetaData &meta_data, bool debug)
      : meta_data_(meta_data), debug_(debug) {
    std::string dict = dict_dir + "/jieba.dict.utf8";
    std::string hmm = dict_dir + "/hmm_model.utf8";
    std::string user_dict = dict_dir + "/user.dict.utf8";
    std::string idf = dict_dir + "/idf.utf8";
    std::string stop_word = dict_dir + "/stop_words.utf8";

    // Fail fast with a clear message if any jieba resource is missing.
    AssertFileExists(dict);
    AssertFileExists(hmm);
    AssertFileExists(user_dict);
    AssertFileExists(idf);
    AssertFileExists(stop_word);

    jieba_ =
        std::make_unique<cppjieba::Jieba>(dict, hmm, user_dict, idf, stop_word);

    {
      std::ifstream is(tokens);
      InitTokens(is);
    }

    {
      std::ifstream is(lexicon);
      InitLexicon(is);
    }
  }

  /// Converts text into token/tone ID sequences, one TokenIDs entry per
  /// sentence. Sentences are split at ASCII or Chinese punctuation.
  std::vector<TokenIDs> ConvertTextToTokenIds(const std::string &_text) const {
    std::string text = ToLowerCase(_text);
    // Normalize less common Chinese punctuation into a small closed set; see
    // https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/text/mandarin.py#L244
    std::regex punct_re{"：|、|；"};
    std::string s = std::regex_replace(text, punct_re, ",");

    std::regex punct_re2("。");
    s = std::regex_replace(s, punct_re2, ".");

    std::regex punct_re3("？");
    s = std::regex_replace(s, punct_re3, "?");

    std::regex punct_re4("！");
    s = std::regex_replace(s, punct_re4, "!");

    std::vector<std::string> words;
    bool is_hmm = true;
    // Bug fix: segment the punctuation-normalized string `s`. Previously the
    // raw `text` was segmented, which made the regex replacements above dead
    // code (their result was only ever logged).
    jieba_->Cut(s, words, is_hmm);

    if (debug_) {
      SHERPA_ONNX_LOGE("input text: %s", text.c_str());
      SHERPA_ONNX_LOGE("after replacing punctuations: %s", s.c_str());

      std::ostringstream os;
      std::string sep = "";
      for (const auto &w : words) {
        os << sep << w;
        sep = "_";
      }

      SHERPA_ONNX_LOGE("after jieba processing: %s", os.str().c_str());
    }

    std::vector<TokenIDs> ans;
    TokenIDs this_sentence;

    for (const auto &w : words) {
      auto ids = ConvertWordToIds(w);
      if (ids.tokens.empty()) {
        SHERPA_ONNX_LOGE("Ignore OOV '%s'", w.c_str());
        continue;
      }

      this_sentence.tokens.insert(this_sentence.tokens.end(),
                                  ids.tokens.begin(), ids.tokens.end());
      this_sentence.tones.insert(this_sentence.tones.end(), ids.tones.begin(),
                                 ids.tones.end());

      // End the current sentence at sentence-final punctuation. The Chinese
      // variants are kept for robustness even though the normalization above
      // should already have mapped them to ASCII.
      if (w == "." || w == "!" || w == "?" || w == "," || w == "，" ||
          w == "。" || w == "！" || w == "？") {
        ans.push_back(std::move(this_sentence));
        this_sentence = {};
      }
    }  // for (const auto &w : words)

    if (!this_sentence.tokens.empty()) {
      ans.push_back(std::move(this_sentence));
    }

    return ans;
  }

 private:
  /// Looks up `w` in the lexicon; falls back to a direct token lookup, then
  /// to per-UTF-8-unit and finally per-byte lookup. Returns empty TokenIDs
  /// for an OOV word.
  TokenIDs ConvertWordToIds(const std::string &w) const {
    if (word2ids_.count(w)) {
      return word2ids_.at(w);
    }

    if (token2id_.count(w)) {
      // A bare phone/token symbol; tone 0 by convention.
      return {{token2id_.at(w)}, {0}};
    }

    TokenIDs ans;

    std::vector<std::string> words = SplitUtf8(w);
    for (const auto &word : words) {
      if (word2ids_.count(word)) {
        auto ids = ConvertWordToIds(word);

        ans.tokens.insert(ans.tokens.end(), ids.tokens.begin(),
                          ids.tokens.end());
        ans.tones.insert(ans.tones.end(), ids.tones.begin(), ids.tones.end());
      } else {
        // If the lexicon does not contain the word, we split the word into
        // single characters.
        //
        // For instance, if the word is TTS and it does not exist
        // in the lexicon, we split it into 3 characters: T T S
        std::string s;
        for (char c : word) {
          s = c;
          if (word2ids_.count(s)) {
            const auto &t = word2ids_.at(s);
            ans.tokens.insert(ans.tokens.end(), t.tokens.begin(),
                              t.tokens.end());
            ans.tones.insert(ans.tones.end(), t.tones.begin(), t.tones.end());
          }
        }
      }
    }

    return ans;
  }

  /// Reads tokens.txt and adds aliases so that ASCII punctuation and its
  /// Chinese counterpart share the same token ID, whichever of the two the
  /// file provides.
  void InitTokens(std::istream &is) {
    token2id_ = ReadTokens(is);
    token2id_[" "] = token2id_["_"];

    std::vector<std::pair<std::string, std::string>> puncts = {
        {",", "，"}, {".", "。"}, {"!", "！"}, {"?", "？"}};

    for (const auto &p : puncts) {
      if (token2id_.count(p.first) && !token2id_.count(p.second)) {
        token2id_[p.second] = token2id_[p.first];
      }

      if (!token2id_.count(p.first) && token2id_.count(p.second)) {
        token2id_[p.first] = token2id_[p.second];
      }
    }

    // Map the enumeration comma 、 to the regular Chinese comma ，.
    if (!token2id_.count("、") && token2id_.count("，")) {
      token2id_["、"] = token2id_["，"];
    }
  }

  /// Parses lexicon lines of the form
  ///   word phone1 ... phoneN tone1 ... toneN
  /// and stores the resulting IDs in word2ids_. Exits on malformed lines.
  void InitLexicon(std::istream &is) {
    std::string word;
    std::vector<std::string> token_list;
    std::vector<std::string> phone_list;
    std::vector<int64_t> tone_list;

    std::string line;
    std::string phone;
    int32_t line_num = 0;

    while (std::getline(is, line)) {
      ++line_num;

      std::istringstream iss(line);

      token_list.clear();
      phone_list.clear();
      tone_list.clear();

      iss >> word;
      ToLowerCase(&word);

      if (word2ids_.count(word)) {
        SHERPA_ONNX_LOGE("Duplicated word: %s at line %d:%s. Ignore it.",
                         word.c_str(), line_num, line.c_str());
        continue;
      }

      while (iss >> phone) {
        token_list.push_back(std::move(phone));
      }

      // The first half of the columns are phones; the second half are their
      // tones, so the column count must be even.
      if ((token_list.size() & 1) != 0) {
        SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
        exit(-1);
      }

      int32_t num_phones = token_list.size() / 2;
      phone_list.reserve(num_phones);
      tone_list.reserve(num_phones);
      for (int32_t i = 0; i != num_phones; ++i) {
        phone_list.push_back(std::move(token_list[i]));
        tone_list.push_back(std::stoi(token_list[i + num_phones], nullptr));
        if (tone_list.back() < 0 || tone_list.back() > 50) {
          SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
          exit(-1);
        }
      }

      std::vector<int32_t> ids = ConvertTokensToIds(token2id_, phone_list);
      if (ids.empty()) {
        continue;
      }

      if (static_cast<int32_t>(ids.size()) != num_phones) {
        SHERPA_ONNX_LOGE("Invalid line %d: '%s'", line_num, line.c_str());
        exit(-1);
      }

      std::vector<int64_t> ids64{ids.begin(), ids.end()};

      word2ids_.insert(
          {std::move(word), TokenIDs{std::move(ids64), std::move(tone_list)}});
    }

    // Aliases for interjections that may be missing from the lexicon.
    // Guarded with count() so operator[] does not default-insert an empty
    // entry when the source word itself is absent.
    if (word2ids_.count("母")) {
      word2ids_["呣"] = word2ids_["母"];
    }
    if (word2ids_.count("恩")) {
      word2ids_["嗯"] = word2ids_["恩"];
    }
  }

 private:
  // lexicon.txt is saved in word2ids_
  std::unordered_map<std::string, TokenIDs> word2ids_;

  // tokens.txt is saved in token2id_
  std::unordered_map<std::string, int32_t> token2id_;

  OfflineTtsVitsModelMetaData meta_data_;

  std::unique_ptr<cppjieba::Jieba> jieba_;
  bool debug_ = false;
};
// Out-of-line defaulted destructor: Impl is only forward-declared in the
// header (pimpl), so ~unique_ptr<Impl> must be instantiated here, where
// Impl is a complete type.
MeloTtsLexicon::~MeloTtsLexicon() = default;
// Constructs the pimpl eagerly; all loading and validation (jieba
// dictionaries, tokens.txt, lexicon.txt) happens inside Impl's constructor.
MeloTtsLexicon::MeloTtsLexicon(const std::string &lexicon,
                               const std::string &tokens,
                               const std::string &dict_dir,
                               const OfflineTtsVitsModelMetaData &meta_data,
                               bool debug)
    : impl_(std::make_unique<Impl>(lexicon, tokens, dict_dir, meta_data,
                                   debug)) {}
std::vector<TokenIDs> MeloTtsLexicon::ConvertTextToTokenIds(
    const std::string &text, const std::string & /*unused_voice = ""*/) const {
  // The voice argument is ignored by this lexicon; simply forward to Impl.
  auto result = impl_->ConvertTextToTokenIds(text);
  return result;
}
} // namespace sherpa_onnx