Fix style issues for online punctuation source files (#1225)
This commit is contained in:
@@ -58,6 +58,7 @@ def get_binaries():
|
|||||||
"sherpa-onnx-offline-tts",
|
"sherpa-onnx-offline-tts",
|
||||||
"sherpa-onnx-offline-tts-play",
|
"sherpa-onnx-offline-tts-play",
|
||||||
"sherpa-onnx-offline-websocket-server",
|
"sherpa-onnx-offline-websocket-server",
|
||||||
|
"sherpa-onnx-online-punctuation",
|
||||||
"sherpa-onnx-online-websocket-client",
|
"sherpa-onnx-online-websocket-client",
|
||||||
"sherpa-onnx-online-websocket-server",
|
"sherpa-onnx-online-websocket-server",
|
||||||
"sherpa-onnx-vad-microphone",
|
"sherpa-onnx-vad-microphone",
|
||||||
|
|||||||
@@ -35,8 +35,11 @@ class OnlineCNNBiLSTMModel::Impl {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) {
|
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids,
|
||||||
std::array<Ort::Value, 3> inputs = {std::move(token_ids), std::move(valid_ids), std::move(label_lens)};
|
Ort::Value valid_ids,
|
||||||
|
Ort::Value label_lens) {
|
||||||
|
std::array<Ort::Value, 3> inputs = {
|
||||||
|
std::move(token_ids), std::move(valid_ids), std::move(label_lens)};
|
||||||
|
|
||||||
auto ans =
|
auto ans =
|
||||||
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
|
||||||
@@ -117,18 +120,18 @@ OnlineCNNBiLSTMModel::OnlineCNNBiLSTMModel(
|
|||||||
|
|
||||||
OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default;
|
OnlineCNNBiLSTMModel::~OnlineCNNBiLSTMModel() = default;
|
||||||
|
|
||||||
std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(Ort::Value token_ids,
|
std::pair<Ort::Value, Ort::Value> OnlineCNNBiLSTMModel::Forward(
|
||||||
Ort::Value valid_ids,
|
Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const {
|
||||||
Ort::Value label_lens) const {
|
return impl_->Forward(std::move(token_ids), std::move(valid_ids),
|
||||||
return impl_->Forward(std::move(token_ids), std::move(valid_ids), std::move(label_lens));
|
std::move(label_lens));
|
||||||
}
|
}
|
||||||
|
|
||||||
OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const {
|
OrtAllocator *OnlineCNNBiLSTMModel::Allocator() const {
|
||||||
return impl_->Allocator();
|
return impl_->Allocator();
|
||||||
}
|
}
|
||||||
|
|
||||||
const OnlineCNNBiLSTMModelMetaData &
|
const OnlineCNNBiLSTMModelMetaData &OnlineCNNBiLSTMModel::GetModelMetadata()
|
||||||
OnlineCNNBiLSTMModel::GetModelMetadata() const {
|
const {
|
||||||
return impl_->GetModelMetadata();
|
return impl_->GetModelMetadata();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,8 +23,7 @@ namespace sherpa_onnx {
|
|||||||
*/
|
*/
|
||||||
class OnlineCNNBiLSTMModel {
|
class OnlineCNNBiLSTMModel {
|
||||||
public:
|
public:
|
||||||
explicit OnlineCNNBiLSTMModel(
|
explicit OnlineCNNBiLSTMModel(const OnlinePunctuationModelConfig &config);
|
||||||
const OnlinePunctuationModelConfig &config);
|
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
OnlineCNNBiLSTMModel(AAssetManager *mgr,
|
OnlineCNNBiLSTMModel(AAssetManager *mgr,
|
||||||
@@ -43,7 +42,9 @@ class OnlineCNNBiLSTMModel {
|
|||||||
* - case_logits: A 2-D tensor of shape (T', num_cases).
|
* - case_logits: A 2-D tensor of shape (T', num_cases).
|
||||||
* - punct_logits: A 2-D tensor of shape (T', num_puncts).
|
* - punct_logits: A 2-D tensor of shape (T', num_puncts).
|
||||||
*/
|
*/
|
||||||
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids, Ort::Value valid_ids, Ort::Value label_lens) const;
|
std::pair<Ort::Value, Ort::Value> Forward(Ort::Value token_ids,
|
||||||
|
Ort::Value valid_ids,
|
||||||
|
Ort::Value label_lens) const;
|
||||||
|
|
||||||
/** Return an allocator for allocating memory
|
/** Return an allocator for allocating memory
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -7,27 +7,28 @@
|
|||||||
|
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
#include "android/asset_manager.h"
|
#include "android/asset_manager.h"
|
||||||
#include "android/asset_manager_jni.h"
|
#include "android/asset_manager_jni.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include <chrono> // NOLINT
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/macros.h"
|
#include "sherpa-onnx/csrc/macros.h"
|
||||||
#include "sherpa-onnx/csrc/math.h"
|
#include "sherpa-onnx/csrc/math.h"
|
||||||
|
#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
|
||||||
#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h"
|
#include "sherpa-onnx/csrc/online-cnn-bilstm-model.h"
|
||||||
#include "sherpa-onnx/csrc/online-punctuation-impl.h"
|
#include "sherpa-onnx/csrc/online-punctuation-impl.h"
|
||||||
#include "sherpa-onnx/csrc/online-punctuation.h"
|
#include "sherpa-onnx/csrc/online-punctuation.h"
|
||||||
#include "sherpa-onnx/csrc/online-cnn-bilstm-model-meta-data.h"
|
|
||||||
#include "sherpa-onnx/csrc/text-utils.h"
|
|
||||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||||
|
#include "sherpa-onnx/csrc/text-utils.h"
|
||||||
#include "ssentencepiece/csrc/ssentencepiece.h"
|
#include "ssentencepiece/csrc/ssentencepiece.h"
|
||||||
#include <chrono> // NOLINT
|
|
||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
@@ -35,8 +36,7 @@ static const int32_t kMaxSeqLen = 200;
|
|||||||
|
|
||||||
class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
|
class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
|
||||||
public:
|
public:
|
||||||
explicit OnlinePunctuationCNNBiLSTMImpl(
|
explicit OnlinePunctuationCNNBiLSTMImpl(const OnlinePunctuationConfig &config)
|
||||||
const OnlinePunctuationConfig &config)
|
|
||||||
: config_(config), model_(config.model) {
|
: config_(config), model_(config.model) {
|
||||||
if (!config_.model.bpe_vocab.empty()) {
|
if (!config_.model.bpe_vocab.empty()) {
|
||||||
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
|
bpe_encoder_ = std::make_unique<ssentencepiece::Ssentencepiece>(
|
||||||
@@ -75,34 +75,43 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
|
|||||||
int32_t n = label_len_list.size();
|
int32_t n = label_len_list.size();
|
||||||
|
|
||||||
std::array<int64_t, 2> token_ids_shape = {n, kMaxSeqLen};
|
std::array<int64_t, 2> token_ids_shape = {n, kMaxSeqLen};
|
||||||
Ort::Value token_ids = Ort::Value::CreateTensor(memory_info, tokens_list.data(), tokens_list.size(),
|
Ort::Value token_ids = Ort::Value::CreateTensor(
|
||||||
|
memory_info, tokens_list.data(), tokens_list.size(),
|
||||||
token_ids_shape.data(), token_ids_shape.size());
|
token_ids_shape.data(), token_ids_shape.size());
|
||||||
|
|
||||||
std::array<int64_t, 2> valid_ids_shape = {n, kMaxSeqLen};
|
std::array<int64_t, 2> valid_ids_shape = {n, kMaxSeqLen};
|
||||||
Ort::Value valid_ids = Ort::Value::CreateTensor(memory_info, valids_list.data(), valids_list.size(),
|
Ort::Value valid_ids = Ort::Value::CreateTensor(
|
||||||
|
memory_info, valids_list.data(), valids_list.size(),
|
||||||
valid_ids_shape.data(), valid_ids_shape.size());
|
valid_ids_shape.data(), valid_ids_shape.size());
|
||||||
|
|
||||||
std::array<int64_t, 1> label_len_shape = {n};
|
std::array<int64_t, 1> label_len_shape = {n};
|
||||||
Ort::Value label_len = Ort::Value::CreateTensor(memory_info, label_len_list.data(), label_len_list.size(),
|
Ort::Value label_len = Ort::Value::CreateTensor(
|
||||||
|
memory_info, label_len_list.data(), label_len_list.size(),
|
||||||
label_len_shape.data(), label_len_shape.size());
|
label_len_shape.data(), label_len_shape.size());
|
||||||
|
|
||||||
auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids), std::move(label_len));
|
auto pair = model_.Forward(std::move(token_ids), std::move(valid_ids),
|
||||||
|
std::move(label_len));
|
||||||
|
|
||||||
std::vector<int32_t> case_pred;
|
std::vector<int32_t> case_pred;
|
||||||
std::vector<int32_t> punct_pred;
|
std::vector<int32_t> punct_pred;
|
||||||
const float* active_case_logits = pair.first.GetTensorData<float>();
|
const float *active_case_logits = pair.first.GetTensorData<float>();
|
||||||
const float* active_punct_logits = pair.second.GetTensorData<float>();
|
const float *active_punct_logits = pair.second.GetTensorData<float>();
|
||||||
std::vector<int64_t> case_logits_shape = pair.first.GetTensorTypeAndShapeInfo().GetShape();
|
std::vector<int64_t> case_logits_shape =
|
||||||
|
pair.first.GetTensorTypeAndShapeInfo().GetShape();
|
||||||
|
|
||||||
for (int32_t i = 0; i < case_logits_shape[0]; ++i) {
|
for (int32_t i = 0; i < case_logits_shape[0]; ++i) {
|
||||||
const float* p_cur_case = active_case_logits + i * meta_data.num_cases;
|
const float *p_cur_case = active_case_logits + i * meta_data.num_cases;
|
||||||
auto index_case = static_cast<int32_t>(std::distance(
|
auto index_case = static_cast<int32_t>(std::distance(
|
||||||
p_cur_case, std::max_element(p_cur_case, p_cur_case + meta_data.num_cases)));
|
p_cur_case,
|
||||||
|
std::max_element(p_cur_case, p_cur_case + meta_data.num_cases)));
|
||||||
case_pred.push_back(index_case);
|
case_pred.push_back(index_case);
|
||||||
|
|
||||||
const float* p_cur_punct = active_punct_logits + i * meta_data.num_punctuations;
|
const float *p_cur_punct =
|
||||||
|
active_punct_logits + i * meta_data.num_punctuations;
|
||||||
auto index_punct = static_cast<int32_t>(std::distance(
|
auto index_punct = static_cast<int32_t>(std::distance(
|
||||||
p_cur_punct, std::max_element(p_cur_punct, p_cur_punct + meta_data.num_punctuations)));
|
p_cur_punct,
|
||||||
|
std::max_element(p_cur_punct,
|
||||||
|
p_cur_punct + meta_data.num_punctuations)));
|
||||||
punct_pred.push_back(index_punct);
|
punct_pred.push_back(index_punct);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,10 +121,10 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void EncodeSentences(const std::string& text,
|
void EncodeSentences(const std::string &text,
|
||||||
std::vector<int32_t>& tokens_list,
|
std::vector<int32_t> &tokens_list, // NOLINT
|
||||||
std::vector<int32_t>& valids_list,
|
std::vector<int32_t> &valids_list, // NOLINT
|
||||||
std::vector<int32_t>& label_len_list) const {
|
std::vector<int32_t> &label_len_list) const { // NOLINT
|
||||||
std::vector<int32_t> tokens;
|
std::vector<int32_t> tokens;
|
||||||
std::vector<int32_t> valids;
|
std::vector<int32_t> valids;
|
||||||
int32_t label_len = 0;
|
int32_t label_len = 0;
|
||||||
@@ -184,9 +193,9 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string DecodeSentences(const std::string& raw_text,
|
std::string DecodeSentences(const std::string &raw_text,
|
||||||
const std::vector<int32_t>& case_pred,
|
const std::vector<int32_t> &case_pred,
|
||||||
const std::vector<int32_t>& punct_pred) const {
|
const std::vector<int32_t> &punct_pred) const {
|
||||||
std::string result_text;
|
std::string result_text;
|
||||||
std::istringstream iss(raw_text);
|
std::istringstream iss(raw_text);
|
||||||
std::vector<std::string> words;
|
std::vector<std::string> words;
|
||||||
@@ -205,7 +214,8 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
|
|||||||
switch (case_pred[i]) {
|
switch (case_pred[i]) {
|
||||||
case 1: // upper
|
case 1: // upper
|
||||||
{
|
{
|
||||||
std::transform(words[i].begin(), words[i].end(), words[i].begin(), [](auto c){ return std::toupper(c); });
|
std::transform(words[i].begin(), words[i].end(), words[i].begin(),
|
||||||
|
[](auto c) { return std::toupper(c); });
|
||||||
result_text += words[i];
|
result_text += words[i];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -217,14 +227,14 @@ class OnlinePunctuationCNNBiLSTMImpl : public OnlinePunctuationImpl {
|
|||||||
}
|
}
|
||||||
case 3: // mix case
|
case 3: // mix case
|
||||||
{
|
{
|
||||||
// TODO:
|
// TODO(frankyoujian):
|
||||||
// Need to add a map containing supported mix case words so that we can fetch the predicted word from the map
|
// Need to add a map containing supported mix case words so that we
|
||||||
// e.g. mcdonald's -> McDonald's
|
// can fetch the predicted word from the map e.g. mcdonald's ->
|
||||||
|
// McDonald's
|
||||||
result_text += words[i];
|
result_text += words[i];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
default:
|
default: {
|
||||||
{
|
|
||||||
result_text += words[i];
|
result_text += words[i];
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create(
|
|||||||
return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(config);
|
return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(config);
|
||||||
}
|
}
|
||||||
|
|
||||||
SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer");
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Please specify a punctuation model and bpe vocab! Return a null "
|
||||||
|
"pointer");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -31,7 +33,9 @@ std::unique_ptr<OnlinePunctuationImpl> OnlinePunctuationImpl::Create(
|
|||||||
return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(mgr, config);
|
return std::make_unique<OnlinePunctuationCNNBiLSTMImpl>(mgr, config);
|
||||||
}
|
}
|
||||||
|
|
||||||
SHERPA_ONNX_LOGE("Please specify a punctuation model and bpe vocab! Return a null pointer");
|
SHERPA_ONNX_LOGE(
|
||||||
|
"Please specify a punctuation model and bpe vocab! Return a null "
|
||||||
|
"pointer");
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -13,8 +13,7 @@ void OnlinePunctuationModelConfig::Register(ParseOptions *po) {
|
|||||||
po->Register("cnn-bilstm", &cnn_bilstm,
|
po->Register("cnn-bilstm", &cnn_bilstm,
|
||||||
"Path to the light-weight CNN-BiLSTM model");
|
"Path to the light-weight CNN-BiLSTM model");
|
||||||
|
|
||||||
po->Register("bpe-vocab", &bpe_vocab,
|
po->Register("bpe-vocab", &bpe_vocab, "Path to the bpe vocab file");
|
||||||
"Path to the bpe vocab file");
|
|
||||||
|
|
||||||
po->Register("num-threads", &num_threads,
|
po->Register("num-threads", &num_threads,
|
||||||
"Number of threads to run the neural network");
|
"Number of threads to run the neural network");
|
||||||
@@ -33,8 +32,7 @@ bool OnlinePunctuationModelConfig::Validate() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!FileExists(cnn_bilstm)) {
|
if (!FileExists(cnn_bilstm)) {
|
||||||
SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist",
|
SHERPA_ONNX_LOGE("--cnn-bilstm '%s' does not exist", cnn_bilstm.c_str());
|
||||||
cnn_bilstm.c_str());
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,8 +42,7 @@ bool OnlinePunctuationModelConfig::Validate() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (!FileExists(bpe_vocab)) {
|
if (!FileExists(bpe_vocab)) {
|
||||||
SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist",
|
SHERPA_ONNX_LOGE("--bpe-vocab '%s' does not exist", bpe_vocab.c_str());
|
||||||
bpe_vocab.c_str());
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -14,9 +14,7 @@
|
|||||||
|
|
||||||
namespace sherpa_onnx {
|
namespace sherpa_onnx {
|
||||||
|
|
||||||
void OnlinePunctuationConfig::Register(ParseOptions *po) {
|
void OnlinePunctuationConfig::Register(ParseOptions *po) { model.Register(po); }
|
||||||
model.Register(po);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool OnlinePunctuationConfig::Validate() const {
|
bool OnlinePunctuationConfig::Validate() const {
|
||||||
if (!model.Validate()) {
|
if (!model.Validate()) {
|
||||||
@@ -46,7 +44,8 @@ OnlinePunctuation::OnlinePunctuation(AAssetManager *mgr,
|
|||||||
|
|
||||||
OnlinePunctuation::~OnlinePunctuation() = default;
|
OnlinePunctuation::~OnlinePunctuation() = default;
|
||||||
|
|
||||||
std::string OnlinePunctuation::AddPunctuationWithCase(const std::string &text) const {
|
std::string OnlinePunctuation::AddPunctuationWithCase(
|
||||||
|
const std::string &text) const {
|
||||||
return impl_->AddPunctuationWithCase(text);
|
return impl_->AddPunctuationWithCase(text);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -40,8 +40,7 @@ class OnlinePunctuation {
|
|||||||
explicit OnlinePunctuation(const OnlinePunctuationConfig &config);
|
explicit OnlinePunctuation(const OnlinePunctuationConfig &config);
|
||||||
|
|
||||||
#if __ANDROID_API__ >= 9
|
#if __ANDROID_API__ >= 9
|
||||||
OnlinePunctuation(AAssetManager *mgr,
|
OnlinePunctuation(AAssetManager *mgr, const OnlinePunctuationConfig &config);
|
||||||
const OnlinePunctuationConfig &config);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
~OnlinePunctuation();
|
~OnlinePunctuation();
|
||||||
|
|||||||
@@ -3,9 +3,9 @@
|
|||||||
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
// Copyright (c) 2024 Jian You (jianyou@cisco.com, Cisco Systems)
|
||||||
|
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <iostream>
|
|
||||||
|
|
||||||
#include <chrono> // NOLINT
|
#include <chrono> // NOLINT
|
||||||
|
#include <iostream>
|
||||||
|
|
||||||
#include "sherpa-onnx/csrc/online-punctuation.h"
|
#include "sherpa-onnx/csrc/online-punctuation.h"
|
||||||
#include "sherpa-onnx/csrc/parse-options.h"
|
#include "sherpa-onnx/csrc/parse-options.h"
|
||||||
|
|||||||
Reference in New Issue
Block a user