Support Chinese vits models (#368)
This commit is contained in:
@@ -76,9 +76,105 @@ static std::vector<int32_t> ConvertTokensToIds(
|
||||
}
|
||||
|
||||
// Constructs a Lexicon by running the four initialisation steps in order.
//
// @param lexicon       Path handed to InitLexicon(), which opens it as a file.
// @param tokens        Handed to InitTokens(), which loads the token table
//                      via ReadTokens().
// @param punctuations  Space-separated punctuation list, handed to
//                      InitPunctuations().
// @param language      Language name handed to InitLanguage(); matching is
//                      case-insensitive and an unrecognised name terminates
//                      the process there.
//
// The language is set first so that a bad language name fails before any
// file is read. InitTokens() runs before InitLexicon(); keep that order
// (presumably the lexicon entries are mapped to ids through the token
// table -- TODO confirm against InitLexicon()).
Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
                 const std::string &punctuations, const std::string &language) {
  InitLanguage(language);
  InitTokens(tokens);
  InitLexicon(lexicon);
  InitPunctuations(punctuations);
}
|
||||
|
||||
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
|
||||
const std::string &text) const {
|
||||
switch (language_) {
|
||||
case Language::kEnglish:
|
||||
return ConvertTextToTokenIdsEnglish(text);
|
||||
case Language::kChinese:
|
||||
return ConvertTextToTokenIdsChinese(text);
|
||||
default:
|
||||
SHERPA_ONNX_LOGE("Unknonw language: %d", static_cast<int32_t>(language_));
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
|
||||
const std::string &text) const {
|
||||
std::vector<std::string> words = SplitUtf8(text);
|
||||
|
||||
std::vector<int64_t> ans;
|
||||
|
||||
ans.push_back(token2id_.at("sil"));
|
||||
|
||||
for (const auto &w : words) {
|
||||
if (!word2ids_.count(w)) {
|
||||
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &token_ids = word2ids_.at(w);
|
||||
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
||||
}
|
||||
ans.push_back(token2id_.at("sil"));
|
||||
ans.push_back(token2id_.at("eos"));
|
||||
return ans;
|
||||
}
|
||||
|
||||
std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
|
||||
const std::string &_text) const {
|
||||
std::string text(_text);
|
||||
ToLowerCase(&text);
|
||||
|
||||
std::vector<std::string> words = SplitUtf8(text);
|
||||
|
||||
std::vector<int64_t> ans;
|
||||
for (const auto &w : words) {
|
||||
if (punctuations_.count(w)) {
|
||||
ans.push_back(token2id_.at(w));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!word2ids_.count(w)) {
|
||||
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &token_ids = word2ids_.at(w);
|
||||
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
||||
if (blank_ != -1) {
|
||||
ans.push_back(blank_);
|
||||
}
|
||||
}
|
||||
|
||||
if (blank_ != -1 && !ans.empty()) {
|
||||
// remove the last blank
|
||||
ans.resize(ans.size() - 1);
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
// Loads the token table and records the id of the blank (" ") token.
//
// @param tokens  Handed to ReadTokens(); its result becomes token2id_.
//
// The blank token is optional: blank_ is assigned only when " " is present
// in the table. An unconditional token2id_.at(" ") must not be used here,
// because it fails for token tables that have no " " entry and makes the
// presence check pointless. When " " is absent blank_ is left untouched;
// ConvertTextToTokenIdsEnglish() treats -1 as "no blank" (presumably the
// member's default value -- TODO confirm in the class declaration).
void Lexicon::InitTokens(const std::string &tokens) {
  token2id_ = ReadTokens(tokens);

  const auto it = token2id_.find(" ");
  if (it != token2id_.end()) {
    blank_ = it->second;
  }
}
|
||||
|
||||
// Sets language_ from a language name.
//
// @param _lang  Language name. It is compared case-insensitively (a
//               lower-cased copy is used); "english" and "chinese" are the
//               accepted values.
//
// Any other name is logged, using the caller's original spelling, and the
// process exits with status -1.
void Lexicon::InitLanguage(const std::string &_lang) {
  std::string normalized(_lang);
  ToLowerCase(&normalized);

  if (normalized == "english") {
    language_ = Language::kEnglish;
    return;
  }

  if (normalized == "chinese") {
    language_ = Language::kChinese;
    return;
  }

  SHERPA_ONNX_LOGE("Unknown language: %s", _lang.c_str());
  exit(-1);
}
|
||||
|
||||
void Lexicon::InitLexicon(const std::string &lexicon) {
|
||||
std::ifstream is(lexicon);
|
||||
|
||||
std::string word;
|
||||
@@ -109,8 +205,9 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
|
||||
}
|
||||
word2ids_.insert({std::move(word), std::move(ids)});
|
||||
}
|
||||
}
|
||||
|
||||
// process punctuations
|
||||
void Lexicon::InitPunctuations(const std::string &punctuations) {
|
||||
std::vector<std::string> punctuation_list;
|
||||
SplitStringToVector(punctuations, " ", false, &punctuation_list);
|
||||
for (auto &s : punctuation_list) {
|
||||
@@ -118,46 +215,4 @@ Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens,
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> Lexicon::ConvertTextToTokenIds(
|
||||
const std::string &_text) const {
|
||||
std::string text(_text);
|
||||
ToLowerCase(&text);
|
||||
|
||||
std::vector<std::string> words;
|
||||
SplitStringToVector(text, " ", false, &words);
|
||||
|
||||
std::vector<int64_t> ans;
|
||||
for (auto w : words) {
|
||||
std::vector<int64_t> prefix;
|
||||
while (!w.empty() && punctuations_.count(std::string(1, w[0]))) {
|
||||
// if w begins with a punctuation
|
||||
prefix.push_back(token2id_.at(std::string(1, w[0])));
|
||||
w = std::string(w.begin() + 1, w.end());
|
||||
}
|
||||
|
||||
std::vector<int64_t> suffix;
|
||||
while (!w.empty() && punctuations_.count(std::string(1, w.back()))) {
|
||||
suffix.push_back(token2id_.at(std::string(1, w.back())));
|
||||
w = std::string(w.begin(), w.end() - 1);
|
||||
}
|
||||
|
||||
if (!word2ids_.count(w)) {
|
||||
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto &token_ids = word2ids_.at(w);
|
||||
ans.insert(ans.end(), prefix.begin(), prefix.end());
|
||||
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
||||
ans.insert(ans.end(), suffix.rbegin(), suffix.rend());
|
||||
ans.push_back(blank_);
|
||||
}
|
||||
|
||||
if (!ans.empty()) {
|
||||
ans.resize(ans.size() - 1);
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
Reference in New Issue
Block a user