Support Chinese vits models (#368)

Author: Fangjun Kuang
Date: 2023-10-18 10:19:10 +08:00
Committed by: GitHub
Parent: 9efe69720d
Commit: 1ee79e3ff5
16 changed files with 326 additions and 62 deletions


@@ -76,9 +76,105 @@ static std::vector<int32_t> ConvertTokensToIds(
}
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
const std::string &punctuations, const std::string &language) {
InitLanguage(language);
InitTokens(tokens);
InitLexicon(lexicon);
InitPunctuations(punctuations);
}
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
const std::string &text) const {
switch (language_) {
case Language::kEnglish:
return ConvertTextToTokenIdsEnglish(text);
case Language::kChinese:
return ConvertTextToTokenIdsChinese(text);
default:
SHERPA_ONNX_LOGE("Unknonw language: %d", static_cast<int32_t>(language_));
exit(-1);
}
return {};
}
std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
const std::string &text) const {
std::vector<std::string> words = SplitUtf8(text);
std::vector<int64_t> ans;
ans.push_back(token2id_.at("sil"));
for (const auto &w : words) {
if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
}
const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
}
ans.push_back(token2id_.at("sil"));
ans.push_back(token2id_.at("eos"));
return ans;
}
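
The Chinese path wraps the per-word lexicon lookups with "sil" at both ends plus a final "eos", and silently skips out-of-vocabulary words. A self-contained sketch of that layout, using made-up token ids rather than any real model's vocabulary:

// Illustrative only: toy ids, not taken from this commit's data files.
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, int64_t> token2id = {
      {"sil", 0}, {"eos", 1}, {"n", 2}, {"i3", 3}, {"h", 4}, {"ao3", 5}};
  std::unordered_map<std::string, std::vector<int64_t>> word2ids = {
      {"你", {2, 3}}, {"好", {4, 5}}};

  std::vector<std::string> words = {"你", "好", "吗"};  // "吗" is OOV here

  std::vector<int64_t> ans;
  ans.push_back(token2id.at("sil"));
  for (const auto &w : words) {
    auto it = word2ids.find(w);
    if (it == word2ids.end()) {
      std::cout << "OOV " << w << ". Ignore it!\n";
      continue;  // OOV words are dropped, as in the code above
    }
    ans.insert(ans.end(), it->second.begin(), it->second.end());
  }
  ans.push_back(token2id.at("sil"));
  ans.push_back(token2id.at("eos"));

  for (auto id : ans) std::cout << id << ' ';  // prints: 0 2 3 4 5 0 1
  std::cout << '\n';
  return 0;
}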
std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);
std::vector<std::string> words = SplitUtf8(text);
std::vector<int64_t> ans;
for (const auto &w : words) {
if (punctuations_.count(w)) {
ans.push_back(token2id_.at(w));
continue;
}
if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
}
const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
if (blank_ != -1) {
ans.push_back(blank_);
}
}
if (blank_ != -1 && !ans.empty()) {
// remove the last blank
ans.resize(ans.size() - 1);
}
return ans;
}
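
For English, the text is lower-cased first, punctuation tokens are emitted directly, and a blank id (the id of " ", when present in the token table) is inserted between words and stripped after the last one. A minimal illustration of the blank handling, again with hypothetical ids:

// Illustrative only: pretend the blank id is 7 and use two toy words.
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

std::vector<int64_t> EnglishLikeIds(const std::vector<std::string> &words) {
  const int64_t blank = 7;
  std::unordered_map<std::string, std::vector<int64_t>> word2ids = {
      {"hello", {10, 11}}, {"world", {12, 13}}};

  std::vector<int64_t> ans;
  for (const auto &w : words) {
    auto it = word2ids.find(w);
    if (it == word2ids.end()) continue;  // OOV: skip
    ans.insert(ans.end(), it->second.begin(), it->second.end());
    ans.push_back(blank);  // blank between consecutive words
  }
  if (!ans.empty()) ans.resize(ans.size() - 1);  // drop the trailing blank
  return ans;  // {"hello", "world"} -> 10 11 7 12 13
}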
void Lexicon::InitTokens(const std::string &tokens) {
token2id_ = ReadTokens(tokens);
if (token2id_.count(" ")) {
blank_ = token2id_.at(" ");
}
}
void Lexicon::InitLanguage(const std::string &_lang) {
std::string lang(_lang);
ToLowerCase(&lang);
if (lang == "english") {
language_ = Language::kEnglish;
} else if (lang == "chinese") {
language_ = Language::kChinese;
} else {
SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str());
exit(-1);
}
}
void Lexicon::InitLexicon(const std::string &lexicon) {
std::ifstream is(lexicon);
std::string word;
@@ -109,8 +205,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
}
word2ids_.insert({std::move(word), std::move(ids)});
}
}
// process punctuations
void Lexicon::InitPunctuations(const std::string &punctuations) {
std::vector<std::string> punctuation_list;
SplitStringToVector(punctuations, " ", false, &punctuation_list);
for (auto &s : punctuation_list) {
@@ -118,46 +215,4 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
}
}
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
const std::string &_text) const {
std::string text(_text);
ToLowerCase(&text);
std::vector<std::string> words;
SplitStringToVector(text, " ", false, &words);
std::vector<int64_t> ans;
for (auto w : words) {
std::vector<int64_t> prefix;
while (!w.empty() && punctuations_.count(std::string(1, w[0]))) {
// if w begins with a punctuation
prefix.push_back(token2id_.at(std::string(1, w[0])));
w = std::string(w.begin() + 1, w.end());
}
std::vector<int64_t> suffix;
while (!w.empty() && punctuations_.count(std::string(1, w.back()))) {
suffix.push_back(token2id_.at(std::string(1, w.back())));
w = std::string(w.begin(), w.end() - 1);
}
if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
}
const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), prefix.begin(), prefix.end());
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
ans.insert(ans.end(), suffix.rbegin(), suffix.rend());
ans.push_back(blank_);
}
if (!ans.empty()) {
ans.resize(ans.size() - 1);
}
return ans;
}
} // namespace sherpa_onnx
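
Taken together, callers now pass a language string ("english" or "chinese") as a fourth constructor argument and obtain VITS token ids from ConvertTextToTokenIds. A usage sketch, not part of this commit; the include path and file names are assumptions:

#include <cstdint>
#include <vector>

#include "sherpa-onnx/csrc/lexicon.h"  // assumed header location

int main() {
  // lexicon.txt / tokens.txt are placeholders for the model's data files;
  // the punctuations are given as a space-separated list.
  sherpa_onnx::Lexicon lexicon("lexicon.txt", "tokens.txt", "! ? . , ; :",
                               /*language=*/"chinese");

  std::vector<int64_t> ids = lexicon.ConvertTextToTokenIds("你好");
  return ids.empty() ? 1 : 0;
}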