// sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h // // Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) #ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_ #define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_ #include #include #include #include #include #include #if __ANDROID_API__ >= 9 #include "android/asset_manager.h" #include "android/asset_manager_jni.h" #endif #include // NOLINT #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/math.h" #include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h" #include "sherpa-onnx/csrc/online-cnn-bilstm-model.h" #include "sherpa-onnx/csrc/online-punctuation-impl.h" #include "sherpa-onnx/csrc/online-punctuation.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/text-utils.h" #include "ssentencepiece/csrc/ssentencepiece.h" namespace sherpa_onnx { static const int32_t kMaxSeqLen = 200; class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { public: explicit OnlinePunctuationCNNBiLSTMImpl(const OnlinePunctuationConfig &config) : config_(config), model_(config.model) { if (!config_.model.bpe_vocab.empty()) { bpe_encoder_ = std::make_unique( config_.model.bpe_vocab); } } #if __ANDROID_API__ >= 9 OnlinePunctuationCNNBiLSTMImpl(AAssetManager *mgr, const OnlinePunctuationConfig &config) : config_(config), model_(mgr, config.model) { if (!config_.model.bpe_vocab.empty()) { auto buf = ReadFile(mgr, config_.model.bpe_vocab); std::istringstream iss(std::string(buf.begin(), buf.end())); bpe_encoder_ = std::make_unique(iss); } } #endif std::string AddPunctuationWithCase(const std::string &text) const override { if (text.empty()) { return {}; } std::vector tokens_list; // N * kMaxSeqLen std::vector valids_list; // N * kMaxSeqLen std::vector label_len_list; // N EncodeSentences(text, tokens_list, valids_list, label_len_list); const auto &meta_data = model_.GetModelMetadata(); auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); int32_t n = label_len_list.size(); std::array token_ids_shape = {n, kMaxSeqLen}; Ort::Value token_ids = Ort::Value::CreateTensor( memory_info, tokens_list.data(), tokens_list.size(), token_ids_shape.data(), token_ids_shape.size()); std::array valid_ids_shape = {n, kMaxSeqLen}; Ort::Value valid_ids = Ort::Value::CreateTensor( memory_info, valids_list.data(), valids_list.size(), valid_ids_shape.data(), valid_ids_shape.size()); std::array label_len_shape = {n}; Ort::Value label_len = Ort::Value::CreateTensor( memory_info, label_len_list.data(), label_len_list.size(), label_len_shape.data(), label_len_shape.size()); auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len)); std::vector case_pred; std::vector punct_pred; const float *active_case_logits = pair.first.GetTensorData(); const float *active_punct_logits = pair.second.GetTensorData(); std::vector case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape(); for (int32_t i = 0; i < case_logits_shape[0]; ++i) { const float *p_cur_case = active_case_logits + i * meta_data.num_cases; auto index_case = static_cast(std::distance( p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases))); case_pred.push_back(index_case); const float *p_cur_punct = active_punct_logits + i * meta_data.num_punctuations; auto index_punct = static_cast(std::distance( p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations))); punct_pred.push_back(index_punct); } std::string ans = DecodeSentences(text, case_pred, punct_pred); return ans; } private: void EncodeSentences(const std::string &text, std::vector &tokens_list, // NOLINT std::vector &valids_list, // NOLINT std::vector &label_len_list) const { // NOLINT std::vector tokens; std::vector valids; int32_t label_len = 0; tokens.push_back(1); // hardcode 1 now, 1 - valids.push_back(1); std::stringstream ss(text); std::string word; while (ss >> word) { std::vector word_tokens; bpe_encoder_->Encode(word, &word_tokens); int32_t seq_len = tokens.size() + word_tokens.size(); if (seq_len > kMaxSeqLen - 1) { tokens.push_back(2); // hardcode 2 now, 2 - valids.push_back(1); label_len = std::count(valids.begin(), valids.end(), 1); if (tokens.size() < kMaxSeqLen) { tokens.resize(kMaxSeqLen, 0); valids.resize(kMaxSeqLen, 0); } assert(tokens.size() == kMaxSeqLen); assert(valids.size() == kMaxSeqLen); tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); valids_list.insert(valids_list.end(), valids.begin(), valids.end()); label_len_list.push_back(label_len); std::vector().swap(tokens); std::vector().swap(valids); label_len = 0; tokens.push_back(1); // hardcode 1 now, 1 - valids.push_back(1); } tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end()); valids.push_back(1); // only the first sub word is valid int32_t remaining_size = static_cast(word_tokens.size()) - 1; if (remaining_size > 0) { int32_t valids_cur_size = static_cast(valids.size()); valids.resize(valids_cur_size + remaining_size, 0); } } if (tokens.size() > 0) { tokens.push_back(2); // hardcode 2 now, 2 - valids.push_back(1); label_len = std::count(valids.begin(), valids.end(), 1); if (tokens.size() < kMaxSeqLen) { tokens.resize(kMaxSeqLen, 0); valids.resize(kMaxSeqLen, 0); } assert(tokens.size() == kMaxSeqLen); assert(valids.size() == kMaxSeqLen); tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); valids_list.insert(valids_list.end(), valids.begin(), valids.end()); label_len_list.push_back(label_len); } } std::string DecodeSentences(const std::string &raw_text, const std::vector &case_pred, const std::vector &punct_pred) const { std::string result_text; std::istringstream iss(raw_text); std::vector words; std::string word; while (iss >> word) { words.emplace_back(word); } assert(words.size() == case_pred.size()); assert(words.size() == punct_pred.size()); for (int32_t i = 0; i < words.size(); ++i) { std::string prefix = ((i != 0) ? " " : ""); result_text += prefix; switch (case_pred[i]) { case 1: // upper { std::transform(words[i].begin(), words[i].end(), words[i].begin(), [](auto c) { return std::toupper(c); }); result_text += words[i]; break; } case 2: // cap { words[i][0] = std::toupper(words[i][0]); result_text += words[i]; break; } case 3: // mix case { // TODO(frankyoujian): // Need to add a map containing supported mix case words so that we // can fetch the predicted word from the map e.g. mcdonald's -> // McDonald's result_text += words[i]; break; } default: { result_text += words[i]; break; } } std::string suffix; switch (punct_pred[i]) { case 1: // comma { suffix = ","; break; } case 2: // period { suffix = "."; break; } case 3: // question { suffix = "?"; break; } default: break; } result_text += suffix; } return result_text; } private: OnlinePunctuationConfig config_; OnlineCNNBiLSTMModel model_; std::unique_ptr bpe_encoder_; }; } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_