Validate input sid (#369)
This commit is contained in:
@@ -104,9 +104,17 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
|
||||
|
||||
std::vector<int64_t> ans;
|
||||
|
||||
ans.push_back(token2id_.at("sil"));
|
||||
auto sil = token2id_.at("sil");
|
||||
auto eos = token2id_.at("eos");
|
||||
|
||||
ans.push_back(sil);
|
||||
|
||||
for (const auto &w : words) {
|
||||
if (punctuations_.count(w)) {
|
||||
ans.push_back(sil);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!word2ids_.count(w)) {
|
||||
SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str());
|
||||
continue;
|
||||
@@ -115,8 +123,8 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsChinese(
|
||||
const auto &token_ids = word2ids_.at(w);
|
||||
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
||||
}
|
||||
ans.push_back(token2id_.at("sil"));
|
||||
ans.push_back(token2id_.at("eos"));
|
||||
ans.push_back(sil);
|
||||
ans.push_back(eos);
|
||||
return ans;
|
||||
}
|
||||
|
||||
@@ -126,6 +134,7 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
|
||||
ToLowerCase(&text);
|
||||
|
||||
std::vector<std::string> words = SplitUtf8(text);
|
||||
int32_t blank = token2id_.at(" ");
|
||||
|
||||
std::vector<int64_t> ans;
|
||||
for (const auto &w : words) {
|
||||
@@ -141,12 +150,10 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
|
||||
|
||||
const auto &token_ids = word2ids_.at(w);
|
||||
ans.insert(ans.end(), token_ids.begin(), token_ids.end());
|
||||
if (blank_ != -1) {
|
||||
ans.push_back(blank_);
|
||||
}
|
||||
ans.push_back(blank);
|
||||
}
|
||||
|
||||
if (blank_ != -1 && !ans.empty()) {
|
||||
if (!ans.empty()) {
|
||||
// remove the last blank
|
||||
ans.resize(ans.size() - 1);
|
||||
}
|
||||
@@ -156,9 +163,6 @@ std::vector<int64_t> Lexicon::ConvertTextToTokenIdsEnglish(
|
||||
|
||||
void Lexicon::InitTokens(const std::string &tokens) {
|
||||
token2id_ = ReadTokens(tokens);
|
||||
if (token2id_.count(" ")) {
|
||||
blank_ = token2id_.at(" ");
|
||||
}
|
||||
}
|
||||
|
||||
void Lexicon::InitLanguage(const std::string &_lang) {
|
||||
|
||||
Reference in New Issue
Block a user