From 1414e4dc619959660574f39c5e90183d677e8cd0 Mon Sep 17 00:00:00 2001 From: jianyou Date: Tue, 6 Aug 2024 17:33:38 +0800 Subject: [PATCH] Add online punctuation and casing prediction model for English language (#1224) --- sherpa-onnx/csrc/CMakeLists.txt | 6 + .../csrc/online-cnn-bilstm-model-meta-data.h | 25 ++ sherpa-onnx/csrc/online-cnn-bilstm-model.cc | 135 +++++++++ sherpa-onnx/csrc/online-cnn-bilstm-model.h | 61 ++++ .../csrc/online-punctuation-cnn-bilstm-impl.h | 268 ++++++++++++++++++ sherpa-onnx/csrc/online-punctuation-impl.cc | 39 +++ sherpa-onnx/csrc/online-punctuation-impl.h | 37 +++ .../csrc/online-punctuation-model-config.cc | 68 +++++ .../csrc/online-punctuation-model-config.h | 42 +++ sherpa-onnx/csrc/online-punctuation.cc | 53 ++++ sherpa-onnx/csrc/online-punctuation.h | 58 ++++ sherpa-onnx/csrc/session.cc | 5 + sherpa-onnx/csrc/session.h | 4 + .../csrc/sherpa-onnx-online-punctuation.cc | 73 +++++ 14 files changed, 874 insertions(+) create mode 100644 sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h create mode 100644 sherpa-onnx/csrc/online-cnn-bilstm-model.cc create mode 100644 sherpa-onnx/csrc/online-cnn-bilstm-model.h create mode 100644 sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h create mode 100644 sherpa-onnx/csrc/online-punctuation-impl.cc create mode 100644 sherpa-onnx/csrc/online-punctuation-impl.h create mode 100644 sherpa-onnx/csrc/online-punctuation-model-config.cc create mode 100644 sherpa-onnx/csrc/online-punctuation-model-config.h create mode 100644 sherpa-onnx/csrc/online-punctuation.cc create mode 100644 sherpa-onnx/csrc/online-punctuation.h create mode 100644 sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index bff22a04..477de5f1 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -138,6 +138,10 @@ list(APPEND sources offline-punctuation-impl.cc offline-punctuation-model-config.cc offline-punctuation.cc + online-cnn-bilstm-model.cc + online-punctuation-impl.cc + online-punctuation-model-config.cc + online-punctuation.cc ) if(SHERPA_ONNX_ENABLE_TTS) @@ -243,6 +247,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc) add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc) add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc) + add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc) if(SHERPA_ONNX_ENABLE_TTS) add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc) @@ -256,6 +261,7 @@ if(SHERPA_ONNX_ENABLE_BINARY) sherpa-onnx-offline-language-identification sherpa-onnx-offline-parallel sherpa-onnx-offline-punctuation + sherpa-onnx-online-punctuation ) if(SHERPA_ONNX_ENABLE_TTS) list(APPEND main_exes diff --git a/sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h b/sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h new file mode 100644 index 00000000..e9532c44 --- /dev/null +++ b/sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h @@ -0,0 +1,25 @@ +// sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_ + +namespace sherpa_onnx { + +struct OnlineCNNBiLSTMModelMetaData { + int32_t comma_id; + int32_t period_id; + int32_t quest_id; + + int32_t upper_id; + int32_t cap_id; + int32_t mix_case_id; + + int32_t num_cases; + int32_t num_punctuations; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_ diff --git a/sherpa-onnx/csrc/online-cnn-bilstm-model.cc b/sherpa-onnx/csrc/online-cnn-bilstm-model.cc new file mode 100644 index 00000000..739cf83f --- /dev/null +++ b/sherpa-onnx/csrc/online-cnn-bilstm-model.cc @@ -0,0 +1,135 @@ +// sherpa-onnx/csrc/online-cnn-bilstm-model.cc +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h" + +#include +#include + +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OnlineCNNBiLSTMModel::Impl { + public: + explicit Impl(const OnlinePunctuationModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(config_.cnn_bilstm); + Init(buf.data(), buf.size()); + } + +#if __ANDROID_API__ >= 9 + Impl(AAssetManager *mgr, const OnlinePunctuationModelConfig &config) + : config_(config), + env_(ORT_LOGGING_LEVEL_ERROR), + sess_opts_(GetSessionOptions(config)), + allocator_{} { + auto buf = ReadFile(mgr, config_.cnn_bilstm); + Init(buf.data(), buf.size()); + } +#endif + + std::pair Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) { + std::array inputs = {std::move(token_ids), std::move(valid_ids), std::move(label_lens)}; + + auto ans = + sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(), + output_names_ptr_.data(), output_names_ptr_.size()); + return {std::move(ans[0]), std::move(ans[1])}; + } + + OrtAllocator *Allocator() const { return allocator_; } + + const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const { + return meta_data_; + } + + private: + void Init(void *model_data, size_t model_data_length) { + sess_ = std::make_unique(env_, model_data, model_data_length, + sess_opts_); + + GetInputNames(sess_.get(), &input_names_, &input_names_ptr_); + + GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = sess_->GetModelMetadata(); + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + + SHERPA_ONNX_READ_META_DATA(meta_data_.comma_id, "COMMA"); + SHERPA_ONNX_READ_META_DATA(meta_data_.period_id, "PERIOD"); + SHERPA_ONNX_READ_META_DATA(meta_data_.quest_id, "QUESTION"); + + // assert here, because we will use the constant value + assert(meta_data_.comma_id == 1); + assert(meta_data_.period_id == 2); + assert(meta_data_.quest_id == 3); + + SHERPA_ONNX_READ_META_DATA(meta_data_.upper_id, "UPPER"); + SHERPA_ONNX_READ_META_DATA(meta_data_.cap_id, "CAP"); + SHERPA_ONNX_READ_META_DATA(meta_data_.mix_case_id, "MIX_CASE"); + + assert(meta_data_.upper_id == 1); + assert(meta_data_.cap_id == 2); + assert(meta_data_.mix_case_id == 3); + + // output shape is (T', num_cases) + meta_data_.num_cases = + sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1]; + meta_data_.num_punctuations = + sess_->GetOutputTypeInfo(1).GetTensorTypeAndShapeInfo().GetShape()[1]; + } + + private: + OnlinePunctuationModelConfig config_; + Ort::Env env_; + Ort::SessionOptions sess_opts_; + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr sess_; + + std::vector input_names_; + std::vector input_names_ptr_; + + std::vector output_names_; + std::vector output_names_ptr_; + + OnlineCNNBiLSTMModelMetaData meta_data_; +}; + +OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel( + const OnlinePunctuationModelConfig &config) + : impl_(std::make_unique(config)) {} + +#if __ANDROID_API__ >= 9 +OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel( + AAssetManager *mgr, const OnlinePunctuationModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} +#endif + +OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default; + +std::pair OnlineCNNBiLSTMModel::Forward(Ort::Value token_ids, + Ort::Value valid_ids, + Ort::Value label_lens) const { + return impl_->Forward(std::move(token_ids), std::move(valid_ids), std::move(label_lens)); +} + +OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const { + return impl_->Allocator(); +} + +const OnlineCNNBiLSTMModelMetaData & +OnlineCNNBiLSTMModel::GetModelMetadata() const { + return impl_->GetModelMetadata(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-cnn-bilstm-model.h b/sherpa-onnx/csrc/online-cnn-bilstm-model.h new file mode 100644 index 00000000..aa0ca2d3 --- /dev/null +++ b/sherpa-onnx/csrc/online-cnn-bilstm-model.h @@ -0,0 +1,61 @@ +// sherpa-onnx/csrc/online-cnn-bilstm-model.h +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_ +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h" +#include "sherpa-onnx/csrc/online-punctuation-model-config.h" + +namespace sherpa_onnx { + +/** This class implements + * https://github.com/frankyoujian/Edge-Punct-Casing/blob/main/onnx_decode_sentence.py + */ +class OnlineCNNBiLSTMModel { + public: + explicit OnlineCNNBiLSTMModel( + const OnlinePunctuationModelConfig &config); + +#if __ANDROID_API__ >= 9 + OnlineCNNBiLSTMModel(AAssetManager *mgr, + const OnlinePunctuationModelConfig &config); +#endif + + ~OnlineCNNBiLSTMModel(); + + /** Run the forward method of the model. + * + * @param token_ids A tensor of shape (N, T) of dtype int32. + * @param valid_ids A tensor of shape (N, T) of dtype int32. + * @param label_lens A tensor of shape (N) of dtype int32. + * + * @return Return a pair of tensors + * - case_logits: A 2-D tensor of shape (T', num_cases). + * - punct_logits: A 2-D tensor of shape (T', num_puncts). + */ + std::pair Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const; + + /** Return an allocator for allocating memory + */ + OrtAllocator *Allocator() const; + + const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_ diff --git a/sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h b/sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h new file mode 100644 index 00000000..aca25bb0 --- /dev/null +++ b/sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h @@ -0,0 +1,268 @@ +// sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_ + +#include + +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/math.h" +#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h" +#include "sherpa-onnx/csrc/online-punctuation-impl.h" +#include "sherpa-onnx/csrc/online-punctuation.h" +#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h" +#include "sherpa-onnx/csrc/text-utils.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "ssentencepiece/csrc/ssentencepiece.h" +#include // NOLINT + +namespace sherpa_onnx { + +static const int32_t kMaxSeqLen = 200; + +class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl { + public: + explicit OnlinePunctuationCNNBiLSTMImpl( + const OnlinePunctuationConfig &config) + : config_(config), model_(config.model) { + if (!config_.model.bpe_vocab.empty()) { + bpe_encoder_ = std::make_unique( + config_.model.bpe_vocab); + } + } + +#if __ANDROID_API__ >= 9 + OnlinePunctuationCNNBiLSTMImpl(AAssetManager *mgr, + const OnlinePunctuationConfig &config) + : config_(config), model_(mgr, config.model) { + if (!config_.model.bpe_vocab.empty()) { + auto buf = ReadFile(mgr, config_.model.bpe_vocab); + std::istringstream iss(std::string(buf.begin(), buf.end())); + bpe_encoder_ = std::make_unique(iss); + } + } +#endif + + std::string AddPunctuationWithCase(const std::string &text) const override { + if (text.empty()) { + return {}; + } + + std::vector tokens_list; // N * kMaxSeqLen + std::vector valids_list; // N * kMaxSeqLen + std::vector label_len_list; // N + + EncodeSentences(text, tokens_list, valids_list, label_len_list); + + const auto &meta_data = model_.GetModelMetadata(); + + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + int32_t n = label_len_list.size(); + + std::array token_ids_shape = {n, kMaxSeqLen}; + Ort::Value token_ids = Ort::Value::CreateTensor(memory_info, tokens_list.data(), tokens_list.size(), + token_ids_shape.data(), token_ids_shape.size()); + + std::array valid_ids_shape = {n, kMaxSeqLen}; + Ort::Value valid_ids = Ort::Value::CreateTensor(memory_info, valids_list.data(), valids_list.size(), + valid_ids_shape.data(), valid_ids_shape.size()); + + std::array label_len_shape = {n}; + Ort::Value label_len = Ort::Value::CreateTensor(memory_info, label_len_list.data(), label_len_list.size(), + label_len_shape.data(), label_len_shape.size()); + + auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len)); + + std::vector case_pred; + std::vector punct_pred; + const float* active_case_logits = pair.first.GetTensorData(); + const float* active_punct_logits = pair.second.GetTensorData(); + std::vector case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape(); + + for (int32_t i = 0; i < case_logits_shape[0]; ++i) { + const float* p_cur_case = active_case_logits + i * meta_data.num_cases; + auto index_case = static_cast(std::distance( + p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases))); + case_pred.push_back(index_case); + + const float* p_cur_punct = active_punct_logits + i * meta_data.num_punctuations; + auto index_punct = static_cast(std::distance( + p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations))); + punct_pred.push_back(index_punct); + } + + std::string ans = DecodeSentences(text, case_pred, punct_pred); + + return ans; + } + + private: + void EncodeSentences(const std::string& text, + std::vector& tokens_list, + std::vector& valids_list, + std::vector& label_len_list) const { + std::vector tokens; + std::vector valids; + int32_t label_len = 0; + + tokens.push_back(1); // hardcode 1 now, 1 - + valids.push_back(1); + + std::stringstream ss(text); + std::string word; + while (ss >> word) { + std::vector word_tokens; + bpe_encoder_->Encode(word, &word_tokens); + + int32_t seq_len = tokens.size() + word_tokens.size(); + if (seq_len > kMaxSeqLen - 1) { + tokens.push_back(2); // hardcode 2 now, 2 - + valids.push_back(1); + + label_len = std::count(valids.begin(), valids.end(), 1); + + if (tokens.size() < kMaxSeqLen) { + tokens.resize(kMaxSeqLen, 0); + valids.resize(kMaxSeqLen, 0); + } + + assert(tokens.size() == kMaxSeqLen); + assert(valids.size() == kMaxSeqLen); + + tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); + valids_list.insert(valids_list.end(), valids.begin(), valids.end()); + label_len_list.push_back(label_len); + + std::vector().swap(tokens); + std::vector().swap(valids); + label_len = 0; + tokens.push_back(1); // hardcode 1 now, 1 - + valids.push_back(1); + } + + tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end()); + valids.push_back(1); // only the first sub word is valid + int32_t remaining_size = static_cast(word_tokens.size()) - 1; + if (remaining_size > 0) { + int32_t valids_cur_size = static_cast(valids.size()); + valids.resize(valids_cur_size + remaining_size, 0); + } + } + + if (tokens.size() > 0) { + tokens.push_back(2); // hardcode 2 now, 2 - + valids.push_back(1); + + label_len = std::count(valids.begin(), valids.end(), 1); + + if (tokens.size() < kMaxSeqLen) { + tokens.resize(kMaxSeqLen, 0); + valids.resize(kMaxSeqLen, 0); + } + + assert(tokens.size() == kMaxSeqLen); + assert(valids.size() == kMaxSeqLen); + + tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end()); + valids_list.insert(valids_list.end(), valids.begin(), valids.end()); + label_len_list.push_back(label_len); + } + } + + std::string DecodeSentences(const std::string& raw_text, + const std::vector& case_pred, + const std::vector& punct_pred) const { + std::string result_text; + std::istringstream iss(raw_text); + std::vector words; + std::string word; + + while (iss >> word) { + words.emplace_back(word); + } + + assert(words.size() == case_pred.size()); + assert(words.size() == punct_pred.size()); + + for (int32_t i = 0; i < words.size(); ++i) { + std::string prefix = ((i != 0) ? " " : ""); + result_text += prefix; + switch (case_pred[i]) { + case 1: // upper + { + std::transform(words[i].begin(), words[i].end(), words[i].begin(), [](auto c){ return std::toupper(c); }); + result_text += words[i]; + break; + } + case 2: // cap + { + words[i][0] = std::toupper(words[i][0]); + result_text += words[i]; + break; + } + case 3: // mix case + { + // TODO: + // Need to add a map containing supported mix case words so that we can fetch the predicted word from the map + // e.g. mcdonald's -> McDonald's + result_text += words[i]; + break; + } + default: + { + result_text += words[i]; + break; + } + } + + std::string suffix; + switch (punct_pred[i]) { + case 1: // comma + { + suffix = ","; + break; + } + case 2: // period + { + suffix = "."; + break; + } + case 3: // question + { + suffix = "?"; + break; + } + default: + break; + } + + result_text += suffix; + } + + return result_text; + } + + private: + OnlinePunctuationConfig config_; + OnlineCNNBiLSTMModel model_; + std::unique_ptr bpe_encoder_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_ diff --git a/sherpa-onnx/csrc/online-punctuation-impl.cc b/sherpa-onnx/csrc/online-punctuation-impl.cc new file mode 100644 index 00000000..2ff0050b --- /dev/null +++ b/sherpa-onnx/csrc/online-punctuation-impl.cc @@ -0,0 +1,39 @@ +// sherpa-onnx/csrc/online-punctuation-impl.cc +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#include "sherpa-onnx/csrc/online-punctuation-impl.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h" + +namespace sherpa_onnx { + +std::unique_ptr OnlinePunctuationImpl::Create( + const OnlinePunctuationConfig &config) { + if (!config.model.cnn_bilstm.empty() && !config.model.bpe_vocab.empty()) { + return std::make_unique(config); + } + + SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer"); + return nullptr; +} + +#if __ANDROID_API__ >= 9 +std::unique_ptr OnlinePunctuationImpl::Create( + AAssetManager *mgr, const OnlinePunctuationConfig &config) { + if (!config.model.cnn_bilstm.empty() && !config.model.bpe_vocab.empty()) { + return std::make_unique(mgr, config); + } + + SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer"); + return nullptr; +} +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-punctuation-impl.h b/sherpa-onnx/csrc/online-punctuation-impl.h new file mode 100644 index 00000000..456f594d --- /dev/null +++ b/sherpa-onnx/csrc/online-punctuation-impl.h @@ -0,0 +1,37 @@ +// sherpa-onnx/csrc/online-punctuation-impl.h +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_ + +#include +#include +#include +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/online-punctuation.h" + +namespace sherpa_onnx { + +class OnlinePunctuationImpl { + public: + virtual ~OnlinePunctuationImpl() = default; + + static std::unique_ptr Create( + const OnlinePunctuationConfig &config); + +#if __ANDROID_API__ >= 9 + static std::unique_ptr Create( + AAssetManager *mgr, const OnlinePunctuationConfig &config); +#endif + + virtual std::string AddPunctuationWithCase(const std::string &text) const = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_ diff --git a/sherpa-onnx/csrc/online-punctuation-model-config.cc b/sherpa-onnx/csrc/online-punctuation-model-config.cc new file mode 100644 index 00000000..8c8b2a30 --- /dev/null +++ b/sherpa-onnx/csrc/online-punctuation-model-config.cc @@ -0,0 +1,68 @@ +// sherpa-onnx/csrc/online-punctuation-model-config.cc +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#include "sherpa-onnx/csrc/online-punctuation-model-config.h" + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +void OnlinePunctuationModelConfig::Register(ParseOptions *po) { + po->Register("cnn-bilstm", &cnn_bilstm, + "Path to the light-weight CNN-BiLSTM model"); + + po->Register("bpe-vocab", &bpe_vocab, + "Path to the bpe vocab file"); + + po->Register("num-threads", &num_threads, + "Number of threads to run the neural network"); + + po->Register("debug", &debug, + "true to print model information while loading it."); + + po->Register("provider", &provider, + "Specify a provider to use: cpu, cuda, coreml"); +} + +bool OnlinePunctuationModelConfig::Validate() const { + if (cnn_bilstm.empty()) { + SHERPA_ONNX_LOGE("Please provide --cnn-bilstm"); + return false; + } + + if (!FileExists(cnn_bilstm)) { + SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist", + cnn_bilstm.c_str()); + return false; + } + + if (bpe_vocab.empty()) { + SHERPA_ONNX_LOGE("Please provide --bpe-vocab"); + return false; + } + + if (!FileExists(bpe_vocab)) { + SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist", + bpe_vocab.c_str()); + return false; + } + + return true; +} + +std::string OnlinePunctuationModelConfig::ToString() const { + std::ostringstream os; + + os << "OnlinePunctuationModelConfig("; + os << "cnn_bilstm=\"" << cnn_bilstm << "\", "; + os << "bpe_vocab=\"" << bpe_vocab << "\", "; + os << "num_threads=" << num_threads << ", "; + os << "debug=" << (debug ? "True" : "False") << ", "; + os << "provider=\"" << provider << "\")"; + + return os.str(); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-punctuation-model-config.h b/sherpa-onnx/csrc/online-punctuation-model-config.h new file mode 100644 index 00000000..2ee2c7c3 --- /dev/null +++ b/sherpa-onnx/csrc/online-punctuation-model-config.h @@ -0,0 +1,42 @@ +// sherpa-onnx/csrc/online-punctuation-model-config.h +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_ + +#include + +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OnlinePunctuationModelConfig { + std::string cnn_bilstm; + std::string bpe_vocab; + + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; + + OnlinePunctuationModelConfig() = default; + + OnlinePunctuationModelConfig(const std::string &cnn_bilstm, + const std::string &bpe_vocab, + int32_t num_threads, bool debug, + const std::string &provider) + : cnn_bilstm(cnn_bilstm), + bpe_vocab(bpe_vocab), + num_threads(num_threads), + debug(debug), + provider(provider) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_ diff --git a/sherpa-onnx/csrc/online-punctuation.cc b/sherpa-onnx/csrc/online-punctuation.cc new file mode 100644 index 00000000..754870a3 --- /dev/null +++ b/sherpa-onnx/csrc/online-punctuation.cc @@ -0,0 +1,53 @@ +// sherpa-onnx/csrc/online-punctuation.cc +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#include "sherpa-onnx/csrc/online-punctuation.h" + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-punctuation-impl.h" + +namespace sherpa_onnx { + +void OnlinePunctuationConfig::Register(ParseOptions *po) { + model.Register(po); +} + +bool OnlinePunctuationConfig::Validate() const { + if (!model.Validate()) { + return false; + } + + return true; +} + +std::string OnlinePunctuationConfig::ToString() const { + std::ostringstream os; + + os << "OnlinePunctuationConfig("; + os << "model=" << model.ToString() << ")"; + + return os.str(); +} + +OnlinePunctuation::OnlinePunctuation(const OnlinePunctuationConfig &config) + : impl_(OnlinePunctuationImpl::Create(config)) {} + +#if __ANDROID_API__ >= 9 +OnlinePunctuation::OnlinePunctuation(AAssetManager *mgr, + const OnlinePunctuationConfig &config) + : impl_(OnlinePunctuationImpl::Create(mgr, config)) {} +#endif + +OnlinePunctuation::~OnlinePunctuation() = default; + +std::string OnlinePunctuation::AddPunctuationWithCase(const std::string &text) const { + return impl_->AddPunctuationWithCase(text); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/online-punctuation.h b/sherpa-onnx/csrc/online-punctuation.h new file mode 100644 index 00000000..a70d336c --- /dev/null +++ b/sherpa-onnx/csrc/online-punctuation.h @@ -0,0 +1,58 @@ +// sherpa-onnx/csrc/online-punctuation.h +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_ +#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_ + +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#include "sherpa-onnx/csrc/online-punctuation-model-config.h" +#include "sherpa-onnx/csrc/parse-options.h" + +namespace sherpa_onnx { + +struct OnlinePunctuationConfig { + OnlinePunctuationModelConfig model; + + OnlinePunctuationConfig() = default; + + explicit OnlinePunctuationConfig(const OnlinePunctuationModelConfig &model) + : model(model) {} + + void Register(ParseOptions *po); + bool Validate() const; + + std::string ToString() const; +}; + +class OnlinePunctuationImpl; + +class OnlinePunctuation { + public: + explicit OnlinePunctuation(const OnlinePunctuationConfig &config); + +#if __ANDROID_API__ >= 9 + OnlinePunctuation(AAssetManager *mgr, + const OnlinePunctuationConfig &config); +#endif + + ~OnlinePunctuation(); + + // Add punctuation and casing to the input text and return it. + std::string AddPunctuationWithCase(const std::string &text) const; + + private: + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_ diff --git a/sherpa-onnx/csrc/session.cc b/sherpa-onnx/csrc/session.cc index 50d4abfe..7f6f685e 100644 --- a/sherpa-onnx/csrc/session.cc +++ b/sherpa-onnx/csrc/session.cc @@ -300,4 +300,9 @@ Ort::SessionOptions GetSessionOptions( return GetSessionOptionsImpl(config.num_threads, config.provider); } +Ort::SessionOptions GetSessionOptions( + const OnlinePunctuationModelConfig &config) { + return GetSessionOptionsImpl(config.num_threads, config.provider); +} + } // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/session.h b/sherpa-onnx/csrc/session.h index 77da79a7..1e8beb11 100644 --- a/sherpa-onnx/csrc/session.h +++ b/sherpa-onnx/csrc/session.h @@ -12,6 +12,7 @@ #include "sherpa-onnx/csrc/offline-lm-config.h" #include "sherpa-onnx/csrc/offline-model-config.h" #include "sherpa-onnx/csrc/offline-punctuation-model-config.h" +#include "sherpa-onnx/csrc/online-punctuation-model-config.h" #include "sherpa-onnx/csrc/online-lm-config.h" #include "sherpa-onnx/csrc/online-model-config.h" #include "sherpa-onnx/csrc/speaker-embedding-extractor.h" @@ -52,6 +53,9 @@ Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config); Ort::SessionOptions GetSessionOptions( const OfflinePunctuationModelConfig &config); +Ort::SessionOptions GetSessionOptions( + const OnlinePunctuationModelConfig &config); + } // namespace sherpa_onnx #endif // SHERPA_ONNX_CSRC_SESSION_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc b/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc new file mode 100644 index 00000000..11f21c36 --- /dev/null +++ b/sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc @@ -0,0 +1,73 @@ +// sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc +// +// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems) + +#include +#include + +#include // NOLINT + +#include "sherpa-onnx/csrc/online-punctuation.h" +#include "sherpa-onnx/csrc/parse-options.h" + +int main(int32_t argc, char *argv[]) { + const char *kUsageMessage = R"usage( +Add punctuations to the input text. + +The input text can contain English words. + +Usage: + +Please download the model from: +https://huggingface.co/frankyoujian/Edge-Punct-Casing/resolve/main/sherpa-onnx-cnn-bilstm-unigram-bpe-en.7z + +./bin/Release/sherpa-onnx-online-punctuation \ + --cnn-bilstm=/path/to/model.onnx \ + --bpe-vocab=/path/to/bpe.vocab \ + "how are you i am fine thank you" + +The output text should look like below: + "How are you? I am fine. Thank you." +)usage"; + + sherpa_onnx::ParseOptions po(kUsageMessage); + sherpa_onnx::OnlinePunctuationConfig config; + config.Register(&po); + po.Read(argc, argv); + if (po.NumArgs() != 1) { + fprintf(stderr, + "Error: Please provide only 1 positional argument containing the " + "input text.\n\n"); + po.PrintUsage(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "%s\n", config.ToString().c_str()); + + if (!config.Validate()) { + fprintf(stderr, "Errors in config!\n"); + return -1; + } + + fprintf(stderr, "Creating OnlinePunctuation ...\n"); + sherpa_onnx::OnlinePunctuation punct(config); + fprintf(stderr, "Started\n"); + const auto begin = std::chrono::steady_clock::now(); + + std::string text = po.GetArg(1); + + std::string text_with_punct_case = punct.AddPunctuationWithCase(text); + + const auto end = std::chrono::steady_clock::now(); + fprintf(stderr, "Done\n"); + + float elapsed_seconds = + std::chrono::duration_cast(end - begin) + .count() / + 1000.; + + fprintf(stderr, "Num threads: %d\n", config.model.num_threads); + fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); + fprintf(stderr, "Input text: %s\n", text.c_str()); + fprintf(stderr, "Output text: %s\n", text_with_punct_case.c_str()); +}