Validate input sid (#369)

This commit is contained in:
Fangjun Kuang
2023-10-18 14:02:01 +08:00
committed by GitHub
parent 1ee79e3ff5
commit 8545c3b7f0
8 changed files with 51 additions and 14 deletions

View File

@@ -104,9 +104,17 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
std::vector<int64_t> ans;
ans.push_back(token2id_.at("sil"));
auto sil = token2id_.at("sil");
auto eos = token2id_.at("eos");
ans.push_back(sil);
for (const auto &w : words) {
if (punctuations_.count(w)) {
ans.push_back(sil);
continue;
}
if (!word2ids_.count(w)) {
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
continue;
@@ -115,8 +123,8 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
}
ans.push_back(token2id_.at("sil"));
ans.push_back(token2id_.at("eos"));
ans.push_back(sil);
ans.push_back(eos);
return ans;
}
@@ -126,6 +134,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
ToLowerCase(&text);
std::vector<std::string> words = SplitUtf8(text);
int32_t blank = token2id_.at(" ");
std::vector<int64_t> ans;
for (const auto &w : words) {
@@ -141,12 +150,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
const auto &token_ids = word2ids_.at(w);
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
if (blank_ != -1) {
ans.push_back(blank_);
}
ans.push_back(blank);
}
if (blank_ != -1 && !ans.empty()) {
if (!ans.empty()) {
// remove the last blank
ans.resize(ans.size() - 1);
}
@@ -156,9 +163,6 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
// Populates token2id_ from `tokens` and, when available, records the id of
// the single-space token in blank_.
//
// @param tokens  Forwarded unchanged to ReadTokens(); the returned table
//                replaces the current contents of token2id_.
//
// NOTE(review): this text comes from a diff rendering without +/- markers;
// the blank_ caching below may be the pre-change version of this function --
// confirm against the actual file before relying on it.
void Lexicon::InitTokens(const std::string &tokens) {
token2id_ = ReadTokens(tokens);
// Only look up " " when the table actually contains it, so the .at() call
// below is never made for a missing key; otherwise blank_ is not assigned
// here.
if (token2id_.count(" ")) {
blank_ = token2id_.at(" ");
}
}
void Lexicon::InitLanguage(const std::string &_lang) {