// sherpa-onnx/csrc/lexicon.cc // // Copyright (c) 2022-2023 Xiaomi Corporation #include "sherpa-onnx/csrc/lexicon.h" #include #include #include #include #include #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { static void ToLowerCase(std::string *in_out) { std::transform(in_out->begin(), in_out->end(), in_out->begin(), [](unsigned char c) { return std::tolower(c); }); } // Note: We don't use SymbolTable here since tokens may contain a blank // in the first column static std::unordered_map ReadTokens( const std::string &tokens) { std::unordered_map token2id; std::ifstream is(tokens); std::string line; std::string sym; int32_t id; while (std::getline(is, line)) { std::istringstream iss(line); iss >> sym; if (iss.eof()) { id = atoi(sym.c_str()); sym = " "; } else { iss >> id; } if (!iss.eof()) { SHERPA_ONNX_LOGE("Error: %s", line.c_str()); exit(-1); } #if 0 if (token2id.count(sym)) { SHERPA_ONNX_LOGE("Duplicated token %s. Line %s. Existing ID: %d", sym.c_str(), line.c_str(), token2id.at(sym)); exit(-1); } #endif token2id.insert({std::move(sym), id}); } return token2id; } static std::vector ConvertTokensToIds( const std::unordered_map &token2id, const std::vector &tokens) { std::vector ids; ids.reserve(tokens.size()); for (const auto &s : tokens) { if (!token2id.count(s)) { return {}; } int32_t id = token2id.at(s); ids.push_back(id); } return ids; } Lexicon::Lexicon(const std::string &lexicon, const std::string &tokens, const std::string &punctuations) { token2id_ = ReadTokens(tokens); blank_ = token2id_.at(" "); std::ifstream is(lexicon); std::string word; std::vector token_list; std::string line; std::string phone; while (std::getline(is, line)) { std::istringstream iss(line); token_list.clear(); iss >> word; ToLowerCase(&word); if (word2ids_.count(word)) { SHERPA_ONNX_LOGE("Duplicated word: %s", word.c_str()); return; } while (iss >> phone) { token_list.push_back(std::move(phone)); } std::vector ids = ConvertTokensToIds(token2id_, token_list); if (ids.empty()) { continue; } word2ids_.insert({std::move(word), std::move(ids)}); } // process punctuations std::vector punctuation_list; SplitStringToVector(punctuations, " ", false, &punctuation_list); for (auto &s : punctuation_list) { punctuations_.insert(std::move(s)); } } std::vector Lexicon::ConvertTextToTokenIds( const std::string &_text) const { std::string text(_text); ToLowerCase(&text); std::vector words; SplitStringToVector(text, " ", false, &words); std::vector ans; for (auto w : words) { std::vector prefix; while (!w.empty() && punctuations_.count(std::string(1, w[0]))) { // if w begins with a punctuation prefix.push_back(token2id_.at(std::string(1, w[0]))); w = std::string(w.begin() + 1, w.end()); } std::vector suffix; while (!w.empty() && punctuations_.count(std::string(1, w.back()))) { suffix.push_back(token2id_.at(std::string(1, w.back()))); w = std::string(w.begin(), w.end() - 1); } if (!word2ids_.count(w)) { SHERPA_ONNX_LOGE("OOV %s. Ignore it!", w.c_str()); continue; } const auto &token_ids = word2ids_.at(w); ans.insert(ans.end(), prefix.begin(), prefix.end()); ans.insert(ans.end(), token_ids.begin(), token_ids.end()); ans.insert(ans.end(), suffix.rbegin(), suffix.rend()); ans.push_back(blank_); } if (!ans.empty()) { ans.resize(ans.size() - 1); } return ans; } } // namespace sherpa_onnx