Add online punctuation and casing prediction model for English language (#1224)
This commit is contained in:
@@ -138,6 +138,10 @@ list(APPEND sources
|
||||
offline-punctuation-impl.cc
|
||||
offline-punctuation-model-config.cc
|
||||
offline-punctuation.cc
|
||||
online-cnn-bilstm-model.cc
|
||||
online-punctuation-impl.cc
|
||||
online-punctuation-model-config.cc
|
||||
online-punctuation.cc
|
||||
)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
@@ -243,6 +247,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
add_executable(sherpa-onnx-offline-language-identification sherpa-onnx-offline-language-identification.cc)
|
||||
add_executable(sherpa-onnx-offline-parallel sherpa-onnx-offline-parallel.cc)
|
||||
add_executable(sherpa-onnx-offline-punctuation sherpa-onnx-offline-punctuation.cc)
|
||||
add_executable(sherpa-onnx-online-punctuation sherpa-onnx-online-punctuation.cc)
|
||||
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
add_executable(sherpa-onnx-offline-tts sherpa-onnx-offline-tts.cc)
|
||||
@@ -256,6 +261,7 @@ if(SHERPA_ONNX_ENABLE_BINARY)
|
||||
sherpa-onnx-offline-language-identification
|
||||
sherpa-onnx-offline-parallel
|
||||
sherpa-onnx-offline-punctuation
|
||||
sherpa-onnx-online-punctuation
|
||||
)
|
||||
if(SHERPA_ONNX_ENABLE_TTS)
|
||||
list(APPEND main_exes
|
||||
|
||||
25
sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h
Normal file
25
sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h
Normal file
@@ -0,0 +1,25 @@
|
||||
// sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h
|
||||
//
|
||||
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OnlineCNNBiLSTMModelMetaData {
|
||||
int32_t comma_id;
|
||||
int32_t period_id;
|
||||
int32_t quest_id;
|
||||
|
||||
int32_t upper_id;
|
||||
int32_t cap_id;
|
||||
int32_t mix_case_id;
|
||||
|
||||
int32_t num_cases;
|
||||
int32_t num_punctuations;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_META_DATA_H_
|
||||
135
sherpa-onnx/csrc/online-cnn-bilstm-model.cc
Normal file
135
sherpa-onnx/csrc/online-cnn-bilstm-model.cc
Normal file
@@ -0,0 +1,135 @@
|
||||
// sherpa-onnx/csrc/online-cnn-bilstm-model.cc
|
||||
//
|
||||
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
||||
|
||||
#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/session.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlineCNNBiLSTMModel::Impl {
|
||||
public:
|
||||
explicit Impl(const OnlinePunctuationModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(config_.cnn_bilstm);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
Impl(AAssetManager *mgr, const OnlinePunctuationModelConfig &config)
|
||||
: config_(config),
|
||||
env_(ORT_LOGGING_LEVEL_ERROR),
|
||||
sess_opts_(GetSessionOptions(config)),
|
||||
allocator_{} {
|
||||
auto buf = ReadFile(mgr, config_.cnn_bilstm);
|
||||
Init(buf.data(), buf.size());
|
||||
}
|
||||
#endif
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) {
|
||||
std::array<Ort::Value, 3> inputs = {std::move(token_ids), std::move(valid_ids), std::move(label_lens)};
|
||||
|
||||
auto ans =
|
||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||
output_names_ptr_.data(), output_names_ptr_.size());
|
||||
return {std::move(ans[0]), std::move(ans[1])};
|
||||
}
|
||||
|
||||
OrtAllocator *Allocator() const { return allocator_; }
|
||||
|
||||
const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const {
|
||||
return meta_data_;
|
||||
}
|
||||
|
||||
private:
|
||||
void Init(void *model_data, size_t model_data_length) {
|
||||
sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
|
||||
sess_opts_);
|
||||
|
||||
GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);
|
||||
|
||||
GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);
|
||||
|
||||
// get meta data
|
||||
Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
|
||||
|
||||
Ort::AllocatorWithDefaultOptions allocator; // used in the macro below
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.comma_id, "COMMA");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.period_id, "PERIOD");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.quest_id, "QUESTION");
|
||||
|
||||
// assert here, because we will use the constant value
|
||||
assert(meta_data_.comma_id == 1);
|
||||
assert(meta_data_.period_id == 2);
|
||||
assert(meta_data_.quest_id == 3);
|
||||
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.upper_id, "UPPER");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.cap_id, "CAP");
|
||||
SHERPA_ONNX_READ_META_DATA(meta_data_.mix_case_id, "MIX_CASE");
|
||||
|
||||
assert(meta_data_.upper_id == 1);
|
||||
assert(meta_data_.cap_id == 2);
|
||||
assert(meta_data_.mix_case_id == 3);
|
||||
|
||||
// output shape is (T', num_cases)
|
||||
meta_data_.num_cases =
|
||||
sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
meta_data_.num_punctuations =
|
||||
sess_->GetOutputTypeInfo(1).GetTensorTypeAndShapeInfo().GetShape()[1];
|
||||
}
|
||||
|
||||
private:
|
||||
OnlinePunctuationModelConfig config_;
|
||||
Ort::Env env_;
|
||||
Ort::SessionOptions sess_opts_;
|
||||
Ort::AllocatorWithDefaultOptions allocator_;
|
||||
|
||||
std::unique_ptr<Ort::Session> sess_;
|
||||
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<const char *> input_names_ptr_;
|
||||
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<const char *> output_names_ptr_;
|
||||
|
||||
OnlineCNNBiLSTMModelMetaData meta_data_;
|
||||
};
|
||||
|
||||
OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(
|
||||
const OnlinePunctuationModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(
|
||||
AAssetManager *mgr, const OnlinePunctuationModelConfig &config)
|
||||
: impl_(std::make_unique<Impl>(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default;
|
||||
|
||||
std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(Ort::Value token_ids,
|
||||
Ort::Value valid_ids,
|
||||
Ort::Value label_lens) const {
|
||||
return impl_->Forward(std::move(token_ids), std::move(valid_ids), std::move(label_lens));
|
||||
}
|
||||
|
||||
OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const {
|
||||
return impl_->Allocator();
|
||||
}
|
||||
|
||||
const OnlineCNNBiLSTMModelMetaData &
|
||||
OnlineCNNBiLSTMModel::GetModelMetadata() const {
|
||||
return impl_->GetModelMetadata();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
61
sherpa-onnx/csrc/online-cnn-bilstm-model.h
Normal file
61
sherpa-onnx/csrc/online-cnn-bilstm-model.h
Normal file
@@ -0,0 +1,61 @@
|
||||
// sherpa-onnx/csrc/online-cnn-bilstm-model.h
|
||||
//
|
||||
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
|
||||
#include "sherpa-onnx/csrc/online-punctuation-model-config.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
/** This class implements
|
||||
* https://github.com/frankyoujian/Edge-Punct-Casing/blob/main/onnx_decode_sentence.py
|
||||
*/
|
||||
class OnlineCNNBiLSTMModel {
|
||||
public:
|
||||
explicit OnlineCNNBiLSTMModel(
|
||||
const OnlinePunctuationModelConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineCNNBiLSTMModel(AAssetManager *mgr,
|
||||
const OnlinePunctuationModelConfig &config);
|
||||
#endif
|
||||
|
||||
~OnlineCNNBiLSTMModel();
|
||||
|
||||
/** Run the forward method of the model.
|
||||
*
|
||||
* @param token_ids A tensor of shape (N, T) of dtype int32.
|
||||
* @param valid_ids A tensor of shape (N, T) of dtype int32.
|
||||
* @param label_lens A tensor of shape (N) of dtype int32.
|
||||
*
|
||||
* @return Return a pair of tensors
|
||||
* - case_logits: A 2-D tensor of shape (T', num_cases).
|
||||
* - punct_logits: A 2-D tensor of shape (T', num_puncts).
|
||||
*/
|
||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const;
|
||||
|
||||
/** Return an allocator for allocating memory
|
||||
*/
|
||||
OrtAllocator *Allocator() const;
|
||||
|
||||
const OnlineCNNBiLSTMModelMetaData &GetModelMetadata() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_CNN_BILSTM_MODEL_H_
|
||||
268
sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h
Normal file
268
sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h
Normal file
@@ -0,0 +1,268 @@
|
||||
// sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/math.h"
|
||||
#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h"
|
||||
#include "sherpa-onnx/csrc/online-punctuation-impl.h"
|
||||
#include "sherpa-onnx/csrc/online-punctuation.h"
|
||||
#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "ssentencepiece/csrc/ssentencepiece.h"
|
||||
#include <chrono> // NOLINT
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
static const int32_t kMaxSeqLen = 200;
|
||||
|
||||
class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
|
||||
public:
|
||||
explicit OnlinePunctuationCNNBiLSTMImpl(
|
||||
const OnlinePunctuationConfig &config)
|
||||
: config_(config), model_(config.model) {
|
||||
if (!config_.model.bpe_vocab.empty()) {
|
||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
|
||||
config_.model.bpe_vocab);
|
||||
}
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlinePunctuationCNNBiLSTMImpl(AAssetManager *mgr,
|
||||
const OnlinePunctuationConfig &config)
|
||||
: config_(config), model_(mgr, config.model) {
|
||||
if (!config_.model.bpe_vocab.empty()) {
|
||||
auto buf = ReadFile(mgr, config_.model.bpe_vocab);
|
||||
std::istringstream iss(std::string(buf.begin(), buf.end()));
|
||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(iss);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
std::string AddPunctuationWithCase(const std::string &text) const override {
|
||||
if (text.empty()) {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::vector<int32_t> tokens_list; // N * kMaxSeqLen
|
||||
std::vector<int32_t> valids_list; // N * kMaxSeqLen
|
||||
std::vector<int32_t> label_len_list; // N
|
||||
|
||||
EncodeSentences(text, tokens_list, valids_list, label_len_list);
|
||||
|
||||
const auto &meta_data = model_.GetModelMetadata();
|
||||
|
||||
auto memory_info =
|
||||
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault);
|
||||
|
||||
int32_t n = label_len_list.size();
|
||||
|
||||
std::array<int64_t, 2> token_ids_shape = {n, kMaxSeqLen};
|
||||
Ort::Value token_ids = Ort::Value::CreateTensor(memory_info, tokens_list.data(), tokens_list.size(),
|
||||
token_ids_shape.data(), token_ids_shape.size());
|
||||
|
||||
std::array<int64_t, 2> valid_ids_shape = {n, kMaxSeqLen};
|
||||
Ort::Value valid_ids = Ort::Value::CreateTensor(memory_info, valids_list.data(), valids_list.size(),
|
||||
valid_ids_shape.data(), valid_ids_shape.size());
|
||||
|
||||
std::array<int64_t, 1> label_len_shape = {n};
|
||||
Ort::Value label_len = Ort::Value::CreateTensor(memory_info, label_len_list.data(), label_len_list.size(),
|
||||
label_len_shape.data(), label_len_shape.size());
|
||||
|
||||
auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len));
|
||||
|
||||
std::vector<int32_t> case_pred;
|
||||
std::vector<int32_t> punct_pred;
|
||||
const float* active_case_logits = pair.first.GetTensorData<float>();
|
||||
const float* active_punct_logits = pair.second.GetTensorData<float>();
|
||||
std::vector<int64_t> case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
for (int32_t i = 0; i < case_logits_shape[0]; ++i) {
|
||||
const float* p_cur_case = active_case_logits + i * meta_data.num_cases;
|
||||
auto index_case = static_cast<int32_t>(std::distance(
|
||||
p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases)));
|
||||
case_pred.push_back(index_case);
|
||||
|
||||
const float* p_cur_punct = active_punct_logits + i * meta_data.num_punctuations;
|
||||
auto index_punct = static_cast<int32_t>(std::distance(
|
||||
p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations)));
|
||||
punct_pred.push_back(index_punct);
|
||||
}
|
||||
|
||||
std::string ans = DecodeSentences(text, case_pred, punct_pred);
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
private:
|
||||
void EncodeSentences(const std::string& text,
|
||||
std::vector<int32_t>& tokens_list,
|
||||
std::vector<int32_t>& valids_list,
|
||||
std::vector<int32_t>& label_len_list) const {
|
||||
std::vector<int32_t> tokens;
|
||||
std::vector<int32_t> valids;
|
||||
int32_t label_len = 0;
|
||||
|
||||
tokens.push_back(1); // hardcode 1 now, 1 - <s>
|
||||
valids.push_back(1);
|
||||
|
||||
std::stringstream ss(text);
|
||||
std::string word;
|
||||
while (ss >> word) {
|
||||
std::vector<int32_t> word_tokens;
|
||||
bpe_encoder_->Encode(word, &word_tokens);
|
||||
|
||||
int32_t seq_len = tokens.size() + word_tokens.size();
|
||||
if (seq_len > kMaxSeqLen - 1) {
|
||||
tokens.push_back(2); // hardcode 2 now, 2 - </s>
|
||||
valids.push_back(1);
|
||||
|
||||
label_len = std::count(valids.begin(), valids.end(), 1);
|
||||
|
||||
if (tokens.size() < kMaxSeqLen) {
|
||||
tokens.resize(kMaxSeqLen, 0);
|
||||
valids.resize(kMaxSeqLen, 0);
|
||||
}
|
||||
|
||||
assert(tokens.size() == kMaxSeqLen);
|
||||
assert(valids.size() == kMaxSeqLen);
|
||||
|
||||
tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end());
|
||||
valids_list.insert(valids_list.end(), valids.begin(), valids.end());
|
||||
label_len_list.push_back(label_len);
|
||||
|
||||
std::vector<int32_t>().swap(tokens);
|
||||
std::vector<int32_t>().swap(valids);
|
||||
label_len = 0;
|
||||
tokens.push_back(1); // hardcode 1 now, 1 - <s>
|
||||
valids.push_back(1);
|
||||
}
|
||||
|
||||
tokens.insert(tokens.end(), word_tokens.begin(), word_tokens.end());
|
||||
valids.push_back(1); // only the first sub word is valid
|
||||
int32_t remaining_size = static_cast<int32_t>(word_tokens.size()) - 1;
|
||||
if (remaining_size > 0) {
|
||||
int32_t valids_cur_size = static_cast<int32_t>(valids.size());
|
||||
valids.resize(valids_cur_size + remaining_size, 0);
|
||||
}
|
||||
}
|
||||
|
||||
if (tokens.size() > 0) {
|
||||
tokens.push_back(2); // hardcode 2 now, 2 - </s>
|
||||
valids.push_back(1);
|
||||
|
||||
label_len = std::count(valids.begin(), valids.end(), 1);
|
||||
|
||||
if (tokens.size() < kMaxSeqLen) {
|
||||
tokens.resize(kMaxSeqLen, 0);
|
||||
valids.resize(kMaxSeqLen, 0);
|
||||
}
|
||||
|
||||
assert(tokens.size() == kMaxSeqLen);
|
||||
assert(valids.size() == kMaxSeqLen);
|
||||
|
||||
tokens_list.insert(tokens_list.end(), tokens.begin(), tokens.end());
|
||||
valids_list.insert(valids_list.end(), valids.begin(), valids.end());
|
||||
label_len_list.push_back(label_len);
|
||||
}
|
||||
}
|
||||
|
||||
std::string DecodeSentences(const std::string& raw_text,
|
||||
const std::vector<int32_t>& case_pred,
|
||||
const std::vector<int32_t>& punct_pred) const {
|
||||
std::string result_text;
|
||||
std::istringstream iss(raw_text);
|
||||
std::vector<std::string> words;
|
||||
std::string word;
|
||||
|
||||
while (iss >> word) {
|
||||
words.emplace_back(word);
|
||||
}
|
||||
|
||||
assert(words.size() == case_pred.size());
|
||||
assert(words.size() == punct_pred.size());
|
||||
|
||||
for (int32_t i = 0; i < words.size(); ++i) {
|
||||
std::string prefix = ((i != 0) ? " " : "");
|
||||
result_text += prefix;
|
||||
switch (case_pred[i]) {
|
||||
case 1: // upper
|
||||
{
|
||||
std::transform(words[i].begin(), words[i].end(), words[i].begin(), [](auto c){ return std::toupper(c); });
|
||||
result_text += words[i];
|
||||
break;
|
||||
}
|
||||
case 2: // cap
|
||||
{
|
||||
words[i][0] = std::toupper(words[i][0]);
|
||||
result_text += words[i];
|
||||
break;
|
||||
}
|
||||
case 3: // mix case
|
||||
{
|
||||
// TODO:
|
||||
// Need to add a map containing supported mix case words so that we can fetch the predicted word from the map
|
||||
// e.g. mcdonald's -> McDonald's
|
||||
result_text += words[i];
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
result_text += words[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
std::string suffix;
|
||||
switch (punct_pred[i]) {
|
||||
case 1: // comma
|
||||
{
|
||||
suffix = ",";
|
||||
break;
|
||||
}
|
||||
case 2: // period
|
||||
{
|
||||
suffix = ".";
|
||||
break;
|
||||
}
|
||||
case 3: // question
|
||||
{
|
||||
suffix = "?";
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
result_text += suffix;
|
||||
}
|
||||
|
||||
return result_text;
|
||||
}
|
||||
|
||||
private:
|
||||
OnlinePunctuationConfig config_;
|
||||
OnlineCNNBiLSTMModel model_;
|
||||
std::unique_ptr<ssentencepiece::Ssentencepiece> bpe_encoder_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_CNN_BILSTM_IMPL_H_
|
||||
39
sherpa-onnx/csrc/online-punctuation-impl.cc
Normal file
39
sherpa-onnx/csrc/online-punctuation-impl.cc
Normal file
@@ -0,0 +1,39 @@
|
||||
// sherpa-onnx/csrc/online-punctuation-impl.cc
|
||||
//
|
||||
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
||||
|
||||
#include "sherpa-onnx/csrc/online-punctuation-impl.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-punctuation-cnn-bilstm-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create(
|
||||
const OnlinePunctuationConfig &config) {
|
||||
if (!config.model.cnn_bilstm.empty() && !config.model.bpe_vocab.empty()) {
|
||||
return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create(
|
||||
AAssetManager *mgr, const OnlinePunctuationConfig &config) {
|
||||
if (!config.model.cnn_bilstm.empty() && !config.model.bpe_vocab.empty()) {
|
||||
return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(mgr, config);
|
||||
}
|
||||
|
||||
SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer");
|
||||
return nullptr;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
37
sherpa-onnx/csrc/online-punctuation-impl.h
Normal file
37
sherpa-onnx/csrc/online-punctuation-impl.h
Normal file
@@ -0,0 +1,37 @@
|
||||
// sherpa-onnx/csrc/online-punctuation-impl.h
|
||||
//
|
||||
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/online-punctuation.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
class OnlinePunctuationImpl {
|
||||
public:
|
||||
virtual ~OnlinePunctuationImpl() = default;
|
||||
|
||||
static std::unique_ptr<OnlinePunctuationImpl> Create(
|
||||
const OnlinePunctuationConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
static std::unique_ptr<OnlinePunctuationImpl> Create(
|
||||
AAssetManager *mgr, const OnlinePunctuationConfig &config);
|
||||
#endif
|
||||
|
||||
virtual std::string AddPunctuationWithCase(const std::string &text) const = 0;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_IMPL_H_
|
||||
68
sherpa-onnx/csrc/online-punctuation-model-config.cc
Normal file
68
sherpa-onnx/csrc/online-punctuation-model-config.cc
Normal file
@@ -0,0 +1,68 @@
|
||||
// sherpa-onnx/csrc/online-punctuation-model-config.cc
|
||||
//
|
||||
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
||||
|
||||
#include "sherpa-onnx/csrc/online-punctuation-model-config.h"
|
||||
|
||||
#include "sherpa-onnx/csrc/file-utils.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OnlinePunctuationModelConfig::Register(ParseOptions *po) {
|
||||
po->Register("cnn-bilstm", &cnn_bilstm,
|
||||
"Path to the light-weight CNN-BiLSTM model");
|
||||
|
||||
po->Register("bpe-vocab", &bpe_vocab,
|
||||
"Path to the bpe vocab file");
|
||||
|
||||
po->Register("num-threads", &num_threads,
|
||||
"Number of threads to run the neural network");
|
||||
|
||||
po->Register("debug", &debug,
|
||||
"true to print model information while loading it.");
|
||||
|
||||
po->Register("provider", &provider,
|
||||
"Specify a provider to use: cpu, cuda, coreml");
|
||||
}
|
||||
|
||||
bool OnlinePunctuationModelConfig::Validate() const {
|
||||
if (cnn_bilstm.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --cnn-bilstm");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(cnn_bilstm)) {
|
||||
SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist",
|
||||
cnn_bilstm.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (bpe_vocab.empty()) {
|
||||
SHERPA_ONNX_LOGE("Please provide --bpe-vocab");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!FileExists(bpe_vocab)) {
|
||||
SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist",
|
||||
bpe_vocab.c_str());
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OnlinePunctuationModelConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OnlinePunctuationModelConfig(";
|
||||
os << "cnn_bilstm=\"" << cnn_bilstm << "\", ";
|
||||
os << "bpe_vocab=\"" << bpe_vocab << "\", ";
|
||||
os << "num_threads=" << num_threads << ", ";
|
||||
os << "debug=" << (debug ? "True" : "False") << ", ";
|
||||
os << "provider=\"" << provider << "\")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
42
sherpa-onnx/csrc/online-punctuation-model-config.h
Normal file
42
sherpa-onnx/csrc/online-punctuation-model-config.h
Normal file
@@ -0,0 +1,42 @@
|
||||
// sherpa-onnx/csrc/online-punctuation-model-config.h
|
||||
//
|
||||
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OnlinePunctuationModelConfig {
|
||||
std::string cnn_bilstm;
|
||||
std::string bpe_vocab;
|
||||
|
||||
int32_t num_threads = 1;
|
||||
bool debug = false;
|
||||
std::string provider = "cpu";
|
||||
|
||||
OnlinePunctuationModelConfig() = default;
|
||||
|
||||
OnlinePunctuationModelConfig(const std::string &cnn_bilstm,
|
||||
const std::string &bpe_vocab,
|
||||
int32_t num_threads, bool debug,
|
||||
const std::string &provider)
|
||||
: cnn_bilstm(cnn_bilstm),
|
||||
bpe_vocab(bpe_vocab),
|
||||
num_threads(num_threads),
|
||||
debug(debug),
|
||||
provider(provider) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_MODEL_CONFIG_H_
|
||||
53
sherpa-onnx/csrc/online-punctuation.cc
Normal file
53
sherpa-onnx/csrc/online-punctuation.cc
Normal file
@@ -0,0 +1,53 @@
|
||||
// sherpa-onnx/csrc/online-punctuation.cc
|
||||
//
|
||||
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
||||
|
||||
#include "sherpa-onnx/csrc/online-punctuation.h"
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-punctuation-impl.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
void OnlinePunctuationConfig::Register(ParseOptions *po) {
|
||||
model.Register(po);
|
||||
}
|
||||
|
||||
bool OnlinePunctuationConfig::Validate() const {
|
||||
if (!model.Validate()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string OnlinePunctuationConfig::ToString() const {
|
||||
std::ostringstream os;
|
||||
|
||||
os << "OnlinePunctuationConfig(";
|
||||
os << "model=" << model.ToString() << ")";
|
||||
|
||||
return os.str();
|
||||
}
|
||||
|
||||
OnlinePunctuation::OnlinePunctuation(const OnlinePunctuationConfig &config)
|
||||
: impl_(OnlinePunctuationImpl::Create(config)) {}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlinePunctuation::OnlinePunctuation(AAssetManager *mgr,
|
||||
const OnlinePunctuationConfig &config)
|
||||
: impl_(OnlinePunctuationImpl::Create(mgr, config)) {}
|
||||
#endif
|
||||
|
||||
OnlinePunctuation::~OnlinePunctuation() = default;
|
||||
|
||||
std::string OnlinePunctuation::AddPunctuationWithCase(const std::string &text) const {
|
||||
return impl_->AddPunctuationWithCase(text);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
58
sherpa-onnx/csrc/online-punctuation.h
Normal file
58
sherpa-onnx/csrc/online-punctuation.h
Normal file
@@ -0,0 +1,58 @@
|
||||
// sherpa-onnx/csrc/online-punctuation.h
|
||||
//
|
||||
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "sherpa-onnx/csrc/online-punctuation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
struct OnlinePunctuationConfig {
|
||||
OnlinePunctuationModelConfig model;
|
||||
|
||||
OnlinePunctuationConfig() = default;
|
||||
|
||||
explicit OnlinePunctuationConfig(const OnlinePunctuationModelConfig &model)
|
||||
: model(model) {}
|
||||
|
||||
void Register(ParseOptions *po);
|
||||
bool Validate() const;
|
||||
|
||||
std::string ToString() const;
|
||||
};
|
||||
|
||||
class OnlinePunctuationImpl;
|
||||
|
||||
class OnlinePunctuation {
|
||||
public:
|
||||
explicit OnlinePunctuation(const OnlinePunctuationConfig &config);
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlinePunctuation(AAssetManager *mgr,
|
||||
const OnlinePunctuationConfig &config);
|
||||
#endif
|
||||
|
||||
~OnlinePunctuation();
|
||||
|
||||
// Add punctuation and casing to the input text and return it.
|
||||
std::string AddPunctuationWithCase(const std::string &text) const;
|
||||
|
||||
private:
|
||||
std::unique_ptr<OnlinePunctuationImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_PUNCTUATION_H_
|
||||
@@ -300,4 +300,9 @@ Ort::SessionOptions GetSessionOptions(
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const OnlinePunctuationModelConfig &config) {
|
||||
return GetSessionOptionsImpl(config.num_threads, config.provider);
|
||||
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "sherpa-onnx/csrc/offline-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-model-config.h"
|
||||
#include "sherpa-onnx/csrc/offline-punctuation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-punctuation-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-lm-config.h"
|
||||
#include "sherpa-onnx/csrc/online-model-config.h"
|
||||
#include "sherpa-onnx/csrc/speaker-embedding-extractor.h"
|
||||
@@ -52,6 +53,9 @@ Ort::SessionOptions GetSessionOptions(const AudioTaggingModelConfig &config);
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const OfflinePunctuationModelConfig &config);
|
||||
|
||||
Ort::SessionOptions GetSessionOptions(
|
||||
const OnlinePunctuationModelConfig &config);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_SESSION_H_
|
||||
|
||||
73
sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc
Normal file
73
sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc
Normal file
@@ -0,0 +1,73 @@
|
||||
// sherpa-onnx/csrc/sherpa-onnx-online-punctuation.cc
|
||||
//
|
||||
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
||||
|
||||
#include <stdio.h>
|
||||
#include <iostream>
|
||||
|
||||
#include <chrono> // NOLINT
|
||||
|
||||
#include "sherpa-onnx/csrc/online-punctuation.h"
|
||||
#include "sherpa-onnx/csrc/parse-options.h"
|
||||
|
||||
int main(int32_t argc, char *argv[]) {
|
||||
const char *kUsageMessage = R"usage(
|
||||
Add punctuations to the input text.
|
||||
|
||||
The input text can contain English words.
|
||||
|
||||
Usage:
|
||||
|
||||
Please download the model from:
|
||||
https://huggingface.co/frankyoujian/Edge-Punct-Casing/resolve/main/sherpa-onnx-cnn-bilstm-unigram-bpe-en.7z
|
||||
|
||||
./bin/Release/sherpa-onnx-online-punctuation \
|
||||
--cnn-bilstm=/path/to/model.onnx \
|
||||
--bpe-vocab=/path/to/bpe.vocab \
|
||||
"how are you i am fine thank you"
|
||||
|
||||
The output text should look like below:
|
||||
"How are you? I am fine. Thank you."
|
||||
)usage";
|
||||
|
||||
sherpa_onnx::ParseOptions po(kUsageMessage);
|
||||
sherpa_onnx::OnlinePunctuationConfig config;
|
||||
config.Register(&po);
|
||||
po.Read(argc, argv);
|
||||
if (po.NumArgs() != 1) {
|
||||
fprintf(stderr,
|
||||
"Error: Please provide only 1 positional argument containing the "
|
||||
"input text.\n\n");
|
||||
po.PrintUsage();
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s\n", config.ToString().c_str());
|
||||
|
||||
if (!config.Validate()) {
|
||||
fprintf(stderr, "Errors in config!\n");
|
||||
return -1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "Creating OnlinePunctuation ...\n");
|
||||
sherpa_onnx::OnlinePunctuation punct(config);
|
||||
fprintf(stderr, "Started\n");
|
||||
const auto begin = std::chrono::steady_clock::now();
|
||||
|
||||
std::string text = po.GetArg(1);
|
||||
|
||||
std::string text_with_punct_case = punct.AddPunctuationWithCase(text);
|
||||
|
||||
const auto end = std::chrono::steady_clock::now();
|
||||
fprintf(stderr, "Done\n");
|
||||
|
||||
float elapsed_seconds =
|
||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
|
||||
.count() /
|
||||
1000.;
|
||||
|
||||
fprintf(stderr, "Num threads: %d\n", config.model.num_threads);
|
||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
|
||||
fprintf(stderr, "Input text: %s\n", text.c_str());
|
||||
fprintf(stderr, "Output text: %s\n", text_with_punct_case.c_str());
|
||||
}
|
||||
Reference in New Issue
Block a user