// sherpa-onnx/csrc/offline-punctuation-ct-transformer-impl.h // // Copyright (c) 2024 Xiaomi Corporation #ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_ #define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_ #include #include #include #include #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/math.h" #include "sherpa-onnx/csrc/offline-ct-transformer-model.h" #include "sherpa-onnx/csrc/offline-punctuation-impl.h" #include "sherpa-onnx/csrc/offline-punctuation.h" #include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl { public: explicit OfflinePunctuationCtTransformerImpl( const OfflinePunctuationConfig &config) : config_(config), model_(config.model) {} std::string AddPunctuation(const std::string &text) const override { if (text.empty()) { return {}; } std::vector tokens = SplitUtf8(text); std::vector token_ids; token_ids.reserve(tokens.size()); const auto &meta_data = model_.GetModelMetadata(); for (const auto &t : tokens) { std::string token = ToLowerCase(t); if (meta_data.token2id.count(token)) { token_ids.push_back(meta_data.token2id.at(token)); } else { token_ids.push_back(meta_data.unk_id); } } auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); int32_t segment_size = 20; int32_t max_len = 200; int32_t num_segments = (token_ids.size() + segment_size - 1) / segment_size; std::vector punctuations; int32_t last = -1; for (int32_t i = 0; i != num_segments; ++i) { int32_t this_start = i * segment_size; // inclusive int32_t this_end = this_start + segment_size; // exclusive if (this_end > token_ids.size()) { this_end = token_ids.size(); } if (last != -1) { this_start = last; } // token_ids[this_start:this_end] is sent to the model std::array x_shape = {1, this_end - this_start}; Ort::Value x = Ort::Value::CreateTensor(memory_info, token_ids.data() + this_start, x_shape[1], x_shape.data(), x_shape.size()); int64_t len_shape = 1; int32_t len = x_shape[1]; Ort::Value x_len = Ort::Value::CreateTensor(memory_info, &len, 1, &len_shape, 1); Ort::Value out = model_.Forward(std::move(x), std::move(x_len)); // [N, T, num_punctuations] std::vector out_shape = out.GetTensorTypeAndShapeInfo().GetShape(); assert(out_shape[0] == 1); assert(out_shape[1] == len); assert(out_shape[2] == meta_data.num_punctuations); std::vector this_punctuations; this_punctuations.reserve(len); const float *p = out.GetTensorData(); for (int32_t k = 0; k != len; ++k, p += meta_data.num_punctuations) { auto index = static_cast(std::distance( p, std::max_element(p, p + meta_data.num_punctuations))); this_punctuations.push_back(index); } // for (int32_t k = 0; k != len; ++k, p += meta_data.num_punctuations) int32_t dot_index = -1; int32_t comma_index = -1; for (int32_t m = this_punctuations.size() - 1; m >= 1; --m) { int32_t punct_id = this_punctuations[m]; if (punct_id == meta_data.dot_id || punct_id == meta_data.quest_id) { dot_index = m; break; } if (comma_index == -1 && punct_id == meta_data.comma_id) { comma_index = m; } } // for (int32_t k = this_punctuations.size() - 1; k >= 1; --k) if (dot_index == -1 && len >= max_len && comma_index != -1) { dot_index = comma_index; this_punctuations[dot_index] = meta_data.dot_id; } if (dot_index == -1) { if (last == -1) { last = this_start; } if (i == num_segments - 1) { dot_index = token_ids.size() - 1; } } else { last = this_start + dot_index + 1; punctuations.insert(punctuations.end(), this_punctuations.begin(), this_punctuations.begin() + (dot_index + 1)); } } // for (int32_t i = 0; i != num_segments; ++i) if (punctuations.size() != token_ids.size() && punctuations.size() + 1 == token_ids.size()) { punctuations.push_back(meta_data.dot_id); } if (punctuations.size() != token_ids.size()) { SHERPA_ONNX_LOGE("%s, %d, %d. Some unexpected things happened", text.c_str(), static_cast(punctuations.size()), static_cast(token_ids.size())); return text; } std::string ans; for (int32_t i = 0; i != static_cast(punctuations.size()); ++i) { const std::string &w = tokens[i]; if (i > 0 && !(ans.back() & 0x80) && !(w[0] & 0x80)) { ans.push_back(' '); } ans.append(w); if (punctuations[i] != meta_data.underline_id) { ans.append(meta_data.id2punct[punctuations[i]]); } } return ans; } private: OfflinePunctuationConfig config_; OfflineCtTransformerModel model_; }; } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_