Add C++ support for streaming NeMo CTC models. (#857)

This commit is contained in:
Fangjun Kuang
2024-05-10 16:26:43 +08:00
committed by GitHub
parent 1eb60e8711
commit 46e4e5b7ac
22 changed files with 782 additions and 41 deletions

View File

@@ -61,6 +61,8 @@ set(sources
online-lm.cc
online-lstm-transducer-model.cc
online-model-config.cc
online-nemo-ctc-model-config.cc
online-nemo-ctc-model.cc
online-paraformer-model-config.cc
online-paraformer-model.cc
online-recognizer-impl.cc

View File

@@ -4,11 +4,12 @@
#ifndef SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
#define SHERPA_ONNX_CSRC_OFFLINE_PUNCTUATION_CT_TRANSFORMER_IMPL_H_
#include <math.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <math.h>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
@@ -61,7 +62,9 @@ class OfflinePunctuationCtTransformerImpl : public OfflinePunctuationImpl {
int32_t segment_size = 20;
int32_t max_len = 200;
int32_t num_segments = ceil(((float)token_ids.size() + segment_size - 1) / segment_size);
int32_t num_segments =
ceil((static_cast<float>(token_ids.size()) + segment_size - 1) /
segment_size);
std::vector<int32_t> punctuations;
int32_t last = -1;

View File

@@ -10,6 +10,7 @@
#include <string>
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/online-nemo-ctc-model.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model.h"
#include "sherpa-onnx/csrc/online-zipformer2-ctc-model.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
@@ -22,6 +23,8 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
return std::make_unique<OnlineWenetCtcModel>(config);
} else if (!config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineZipformer2CtcModel>(config);
} else if (!config.nemo_ctc.model.empty()) {
return std::make_unique<OnlineNeMoCtcModel>(config);
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);
@@ -36,6 +39,8 @@ std::unique_ptr<OnlineCtcModel> OnlineCtcModel::Create(
return std::make_unique<OnlineWenetCtcModel>(mgr, config);
} else if (!config.zipformer2_ctc.model.empty()) {
return std::make_unique<OnlineZipformer2CtcModel>(mgr, config);
} else if (!config.nemo_ctc.model.empty()) {
return std::make_unique<OnlineNeMoCtcModel>(mgr, config);
} else {
SHERPA_ONNX_LOGE("Please specify a CTC model");
exit(-1);

View File

@@ -15,6 +15,7 @@ void OnlineModelConfig::Register(ParseOptions *po) {
paraformer.Register(po);
wenet_ctc.Register(po);
zipformer2_ctc.Register(po);
nemo_ctc.Register(po);
po->Register("tokens", &tokens, "Path to tokens.txt");
@@ -31,11 +32,11 @@ void OnlineModelConfig::Register(ParseOptions *po) {
po->Register("provider", &provider,
"Specify a provider to use: cpu, cuda, coreml");
po->Register(
"model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: conformer, lstm, zipformer, zipformer2, wenet_ctc"
"All other values lead to loading the model twice.");
po->Register("model-type", &model_type,
"Specify it to reduce model initialization time. "
"Valid values are: conformer, lstm, zipformer, zipformer2, "
"wenet_ctc, nemo_ctc. "
"All other values lead to loading the model twice.");
}
bool OnlineModelConfig::Validate() const {
@@ -61,6 +62,10 @@ bool OnlineModelConfig::Validate() const {
return zipformer2_ctc.Validate();
}
if (!nemo_ctc.model.empty()) {
return nemo_ctc.Validate();
}
return transducer.Validate();
}
@@ -72,6 +77,7 @@ std::string OnlineModelConfig::ToString() const {
os << "paraformer=" << paraformer.ToString() << ", ";
os << "wenet_ctc=" << wenet_ctc.ToString() << ", ";
os << "zipformer2_ctc=" << zipformer2_ctc.ToString() << ", ";
os << "nemo_ctc=" << nemo_ctc.ToString() << ", ";
os << "tokens=\"" << tokens << "\", ";
os << "num_threads=" << num_threads << ", ";
os << "warm_up=" << warm_up << ", ";

View File

@@ -6,6 +6,7 @@
#include <string>
#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
#include "sherpa-onnx/csrc/online-paraformer-model-config.h"
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
#include "sherpa-onnx/csrc/online-wenet-ctc-model-config.h"
@@ -18,6 +19,7 @@ struct OnlineModelConfig {
OnlineParaformerModelConfig paraformer;
OnlineWenetCtcModelConfig wenet_ctc;
OnlineZipformer2CtcModelConfig zipformer2_ctc;
OnlineNeMoCtcModelConfig nemo_ctc;
std::string tokens;
int32_t num_threads = 1;
int32_t warm_up = 0;
@@ -30,6 +32,7 @@ struct OnlineModelConfig {
// - zipformer, zipformer transducer from icefall
// - zipformer2, zipformer2 transducer or CTC from icefall
// - wenet_ctc, wenet CTC model
// - nemo_ctc, NeMo CTC model
//
// All other values are invalid and lead to loading the model twice.
std::string model_type;
@@ -39,6 +42,7 @@ struct OnlineModelConfig {
const OnlineParaformerModelConfig &paraformer,
const OnlineWenetCtcModelConfig &wenet_ctc,
const OnlineZipformer2CtcModelConfig &zipformer2_ctc,
const OnlineNeMoCtcModelConfig &nemo_ctc,
const std::string &tokens, int32_t num_threads,
int32_t warm_up, bool debug, const std::string &provider,
const std::string &model_type)
@@ -46,6 +50,7 @@ struct OnlineModelConfig {
paraformer(paraformer),
wenet_ctc(wenet_ctc),
zipformer2_ctc(zipformer2_ctc),
nemo_ctc(nemo_ctc),
tokens(tokens),
num_threads(num_threads),
warm_up(warm_up),

View File

@@ -0,0 +1,36 @@
// sherpa-onnx/csrc/online-nemo-ctc-model-config.cc
//
// Copyright (c) 2024 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-nemo-ctc-model-config.h"
#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
namespace sherpa_onnx {
// Registers the "nemo-ctc-model" option with `po`; the option is bound to
// the `model` member of this config.
void OnlineNeMoCtcModelConfig::Register(ParseOptions *po) {
  const char *const model_help =
      "Path to CTC model.onnx from NeMo. Please see "
      "https://github.com/k2-fsa/sherpa-onnx/pull/843";

  po->Register("nemo-ctc-model", &model, model_help);
}
// Returns true if FileExists() reports that `model` exists. Otherwise logs
// an error naming the path and returns false.
bool OnlineNeMoCtcModelConfig::Validate() const {
  const bool model_found = FileExists(model);
  if (model_found) {
    return true;
  }

  SHERPA_ONNX_LOGE("NeMo CTC model '%s' does not exist", model.c_str());
  return false;
}
// Returns a textual representation of this config of the form
//   OnlineNeMoCtcModelConfig(model="<path>")
std::string OnlineNeMoCtcModelConfig::ToString() const {
  std::string repr = "OnlineNeMoCtcModelConfig(";
  repr += "model=\"";
  repr += model;
  repr += "\")";
  return repr;
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,28 @@
// sherpa-onnx/csrc/online-nemo-ctc-model-config.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
#define SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_
#include <string>
#include "sherpa-onnx/csrc/parse-options.h"
namespace sherpa_onnx {
// Configuration of a streaming (online) NeMo CTC model.
struct OnlineNeMoCtcModelConfig {
// Path to the CTC model.onnx file exported from NeMo.
std::string model;
// Leaves `model` empty.
OnlineNeMoCtcModelConfig() = default;
// @param model Path to the model file; stored as-is.
explicit OnlineNeMoCtcModelConfig(const std::string &model) : model(model) {}
// Registers the command-line option(s) of this config with `po`.
void Register(ParseOptions *po);
// Returns true if the config is usable, false otherwise.
bool Validate() const;
// Returns a human-readable representation of this config.
std::string ToString() const;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_CONFIG_H_

View File

@@ -0,0 +1,324 @@
// sherpa-onnx/csrc/online-nemo-ctc-model.cc
//
// Copyright (c) 2023 Xiaomi Corporation
#include "sherpa-onnx/csrc/online-nemo-ctc-model.h"

#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif

#include "sherpa-onnx/csrc/cat.h"
#include "sherpa-onnx/csrc/macros.h"
#include "sherpa-onnx/csrc/onnx-utils.h"
#include "sherpa-onnx/csrc/session.h"
#include "sherpa-onnx/csrc/text-utils.h"
#include "sherpa-onnx/csrc/transpose.h"
#include "sherpa-onnx/csrc/unbind.h"
namespace sherpa_onnx {
// Private implementation of OnlineNeMoCtcModel (pimpl).
//
// It owns the onnxruntime session created from config.nemo_ctc.model, the
// integer values read from the model's metadata, and the zero-initialized
// cache tensors handed out by GetInitStates().
class OnlineNeMoCtcModel::Impl {
 public:
  // Reads the model from the path config.nemo_ctc.model and then initializes
  // the session, the metadata values and the initial states.
  explicit Impl(const OnlineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_ERROR),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto buf = ReadFile(config.nemo_ctc.model);
    Init(buf.data(), buf.size());
  }

#if __ANDROID_API__ >= 9
  // Same as the overload above, except that the model is read with
  // ReadFile(mgr, ...).
  //
  // NOTE(review): this overload uses ORT_LOGGING_LEVEL_WARNING while the
  // overload above uses ORT_LOGGING_LEVEL_ERROR; confirm that the difference
  // is intentional.
  Impl(AAssetManager *mgr, const OnlineModelConfig &config)
      : config_(config),
        env_(ORT_LOGGING_LEVEL_WARNING),
        sess_opts_(GetSessionOptions(config)),
        allocator_{} {
    auto buf = ReadFile(mgr, config.nemo_ctc.model);
    Init(buf.data(), buf.size());
  }
#endif

  // Runs the model on one chunk of features.
  //
  // @param x Input features. Dimension 0 is used as the batch size and
  //          dimensions 1 and 2 are swapped before running the session
  //          ((B, T, C) -> (B, C, T)).
  // @param states Exactly 3 tensors, in the order cache_last_channel,
  //               cache_last_time, cache_last_channel_len.
  // @return All outputs of the session except out[1] (logit_length), i.e.,
  //         ans[0] is the logit tensor and ans[1:] are the next states.
  std::vector<Ort::Value> Forward(Ort::Value x,
                                  std::vector<Ort::Value> states) {
    // states[0], states[1] and states[2] are accessed below.
    assert(states.size() == 3);

    Ort::Value &cache_last_channel = states[0];
    Ort::Value &cache_last_time = states[1];
    Ort::Value &cache_last_channel_len = states[2];

    // GetShape() returns int64_t; keep that type to avoid narrowing.
    const int64_t batch_size = x.GetTensorTypeAndShapeInfo().GetShape()[0];

    std::array<int64_t, 1> length_shape{batch_size};

    Ort::Value length = Ort::Value::CreateTensor<int64_t>(
        allocator_, length_shape.data(), length_shape.size());

    // Every utterance in the batch gets the same length: ChunkLength().
    int64_t *p_length = length.GetTensorMutableData<int64_t>();
    std::fill(p_length, p_length + batch_size,
              static_cast<int64_t>(ChunkLength()));

    // (B, T, C) -> (B, C, T)
    x = Transpose12(allocator_, &x);

    std::array<Ort::Value, 5> inputs = {
        std::move(x), View(&length), std::move(cache_last_channel),
        std::move(cache_last_time), std::move(cache_last_channel_len)};

    auto out =
        sess_->Run({}, input_names_ptr_.data(), inputs.data(), inputs.size(),
                   output_names_ptr_.data(), output_names_ptr_.size());
    // out[0]: logit
    // out[1] logit_length
    // out[2:] states_next
    //
    // we need to remove out[1]

    std::vector<Ort::Value> ans;
    if (!out.empty()) {
      // Guard against unsigned underflow of out.size() - 1.
      ans.reserve(out.size() - 1);
    }

    for (size_t i = 0; i != out.size(); ++i) {
      if (i == 1) {
        continue;
      }

      ans.push_back(std::move(out[i]));
    }

    return ans;
  }

  // Vocabulary size read from the metadata plus 1 (see Init()).
  int32_t VocabSize() const { return vocab_size_; }

  // Value of the "window_size" metadata entry.
  int32_t ChunkLength() const { return window_size_; }

  // Value of the "chunk_shift" metadata entry.
  int32_t ChunkShift() const { return chunk_shift_; }

  OrtAllocator *Allocator() const { return allocator_; }

  // Return a vector containing 3 tensors
  // - cache_last_channel
  // - cache_last_time_
  // - cache_last_channel_len
  //
  // Each returned value is created with View() from a tensor owned by this
  // object (presumably sharing its memory; see View() in onnx-utils.h).
  std::vector<Ort::Value> GetInitStates() {
    std::vector<Ort::Value> ans;
    ans.reserve(3);
    ans.push_back(View(&cache_last_channel_));
    ans.push_back(View(&cache_last_time_));
    ans.push_back(View(&cache_last_channel_len_));

    return ans;
  }

  // Combines the states of several streams into batched states.
  //
  // @param states states[b] holds the 3 state tensors of stream b.
  // @return 3 tensors; tensor i is the concatenation along dimension 0 of
  //         states[b][i] over all b. If there is only one stream, its states
  //         are returned unchanged.
  std::vector<Ort::Value> StackStates(
      std::vector<std::vector<Ort::Value>> states) const {
    const int32_t batch_size = static_cast<int32_t>(states.size());
    if (batch_size == 1) {
      return std::move(states[0]);
    }

    std::vector<Ort::Value> ans;
    ans.reserve(3);

    // Scratch list of the i-th state of every stream; reused for each i.
    std::vector<const Ort::Value *> buf;
    buf.reserve(batch_size);

    // there are 3 states to be stacked
    for (int32_t i = 0; i != 3; ++i) {
      buf.clear();

      for (int32_t b = 0; b != batch_size; ++b) {
        assert(states[b].size() == 3);
        buf.push_back(&states[b][i]);
      }

      Ort::Value c{nullptr};
      if (i == 2) {
        // cache_last_channel_len is an int64_t tensor (see InitStates()).
        c = Cat<int64_t>(allocator_, buf, 0);
      } else {
        c = Cat(allocator_, buf, 0);
      }

      ans.push_back(std::move(c));
    }

    return ans;
  }

  // Inverse of StackStates().
  //
  // @param states 3 batched state tensors; dimension 0 of states[0] is used
  //               as the batch size.
  // @return ans[b] holds the 3 state tensors of stream b. If the batch size
  //         is 1, `states` is returned unchanged as ans[0].
  std::vector<std::vector<Ort::Value>> UnStackStates(
      std::vector<Ort::Value> states) const {
    assert(states.size() == 3);

    const auto shape = states[0].GetTensorTypeAndShapeInfo().GetShape();
    const int32_t batch_size = static_cast<int32_t>(shape[0]);

    std::vector<std::vector<Ort::Value>> ans;
    ans.resize(batch_size);

    if (batch_size == 1) {
      ans[0] = std::move(states);
      return ans;
    }

    for (auto &per_stream : ans) {
      per_stream.reserve(3);
    }

    for (int32_t i = 0; i != 3; ++i) {
      std::vector<Ort::Value> v;
      if (i == 2) {
        // cache_last_channel_len is an int64_t tensor (see InitStates()).
        v = Unbind<int64_t>(allocator_, &states[i], 0);
      } else {
        v = Unbind(allocator_, &states[i], 0);
      }

      assert(static_cast<int32_t>(v.size()) == batch_size);

      for (int32_t b = 0; b != batch_size; ++b) {
        ans[b].push_back(std::move(v[b]));
      }
    }

    return ans;
  }

 private:
  // Creates the session from the in-memory model, caches the input/output
  // names, reads the required metadata entries and builds the initial
  // states.
  void Init(void *model_data, size_t model_data_length) {
    sess_ = std::make_unique<Ort::Session>(env_, model_data, model_data_length,
                                           sess_opts_);

    GetInputNames(sess_.get(), &input_names_, &input_names_ptr_);

    GetOutputNames(sess_.get(), &output_names_, &output_names_ptr_);

    // get meta data
    // NOTE: keep the local names `meta_data` and `allocator` unchanged;
    // SHERPA_ONNX_READ_META_DATA presumably refers to both by name.
    Ort::ModelMetadata meta_data = sess_->GetModelMetadata();
    if (config_.debug) {
      std::ostringstream os;
      PrintModelMetadata(os, meta_data);
      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
    }

    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
    SHERPA_ONNX_READ_META_DATA(window_size_, "window_size");
    SHERPA_ONNX_READ_META_DATA(chunk_shift_, "chunk_shift");
    SHERPA_ONNX_READ_META_DATA(subsampling_factor_, "subsampling_factor");
    SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim1_,
                               "cache_last_channel_dim1");
    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim2_,
                               "cache_last_channel_dim2");
    SHERPA_ONNX_READ_META_DATA(cache_last_channel_dim3_,
                               "cache_last_channel_dim3");
    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim1_, "cache_last_time_dim1");
    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim2_, "cache_last_time_dim2");
    SHERPA_ONNX_READ_META_DATA(cache_last_time_dim3_, "cache_last_time_dim3");

    // need to increase by 1 since the blank token is not included in computing
    // vocab_size in NeMo.
    vocab_size_ += 1;

    InitStates();
  }

  // Allocates the three initial state tensors for a batch size of 1 and sets
  // all of their elements to 0.
  void InitStates() {
    std::array<int64_t, 4> cache_last_channel_shape{1, cache_last_channel_dim1_,
                                                    cache_last_channel_dim2_,
                                                    cache_last_channel_dim3_};

    cache_last_channel_ = Ort::Value::CreateTensor<float>(
        allocator_, cache_last_channel_shape.data(),
        cache_last_channel_shape.size());

    Fill<float>(&cache_last_channel_, 0);

    std::array<int64_t, 4> cache_last_time_shape{
        1, cache_last_time_dim1_, cache_last_time_dim2_, cache_last_time_dim3_};

    cache_last_time_ = Ort::Value::CreateTensor<float>(
        allocator_, cache_last_time_shape.data(), cache_last_time_shape.size());

    Fill<float>(&cache_last_time_, 0);

    int64_t shape = 1;
    cache_last_channel_len_ =
        Ort::Value::CreateTensor<int64_t>(allocator_, &shape, 1);

    cache_last_channel_len_.GetTensorMutableData<int64_t>()[0] = 0;
  }

 private:
  OnlineModelConfig config_;
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> sess_;

  std::vector<std::string> input_names_;
  std::vector<const char *> input_names_ptr_;

  std::vector<std::string> output_names_;
  std::vector<const char *> output_names_ptr_;

  // The values below are overwritten in Init() from the model metadata; the
  // initializers only make sure they never hold indeterminate values.
  int32_t window_size_ = 0;
  int32_t chunk_shift_ = 0;
  int32_t subsampling_factor_ = 0;
  int32_t vocab_size_ = 0;
  int32_t cache_last_channel_dim1_ = 0;
  int32_t cache_last_channel_dim2_ = 0;
  int32_t cache_last_channel_dim3_ = 0;
  int32_t cache_last_time_dim1_ = 0;
  int32_t cache_last_time_dim2_ = 0;
  int32_t cache_last_time_dim3_ = 0;

  // Initial states created by InitStates().
  Ort::Value cache_last_channel_{nullptr};
  Ort::Value cache_last_time_{nullptr};
  Ort::Value cache_last_channel_len_{nullptr};
};
// Creates the model by constructing the private Impl from `config`.
OnlineNeMoCtcModel::OnlineNeMoCtcModel(const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(config)) {}
#if __ANDROID_API__ >= 9
// Android variant: constructs the private Impl from `mgr` and `config`.
OnlineNeMoCtcModel::OnlineNeMoCtcModel(AAssetManager *mgr,
const OnlineModelConfig &config)
: impl_(std::make_unique<Impl>(mgr, config)) {}
#endif
// Defaulted in this file, where Impl is a complete type, so that
// std::unique_ptr<Impl> can destroy it.
OnlineNeMoCtcModel::~OnlineNeMoCtcModel() = default;
// Forwards to Impl::Forward(), handing over ownership of `x` and `states`.
std::vector<Ort::Value> OnlineNeMoCtcModel::Forward(
    Ort::Value x, std::vector<Ort::Value> states) const {
  std::vector<Ort::Value> outputs =
      impl_->Forward(std::move(x), std::move(states));

  return outputs;
}
// Forwards to Impl::VocabSize().
int32_t OnlineNeMoCtcModel::VocabSize() const { return impl_->VocabSize(); }
// Forwards to Impl::ChunkLength().
int32_t OnlineNeMoCtcModel::ChunkLength() const { return impl_->ChunkLength(); }
// Forwards to Impl::ChunkShift().
int32_t OnlineNeMoCtcModel::ChunkShift() const { return impl_->ChunkShift(); }
// Forwards to Impl::Allocator(); the returned allocator is the one held by
// the Impl object.
OrtAllocator *OnlineNeMoCtcModel::Allocator() const {
return impl_->Allocator();
}
// Forwards to Impl::GetInitStates(), which returns 3 tensors:
// cache_last_channel, cache_last_time and cache_last_channel_len.
std::vector<Ort::Value> OnlineNeMoCtcModel::GetInitStates() const {
return impl_->GetInitStates();
}
// Forwards to Impl::StackStates(), handing over ownership of `states`.
std::vector<Ort::Value> OnlineNeMoCtcModel::StackStates(
    std::vector<std::vector<Ort::Value>> states) const {
  std::vector<Ort::Value> stacked = impl_->StackStates(std::move(states));

  return stacked;
}
// Forwards to Impl::UnStackStates(), handing over ownership of `states`.
std::vector<std::vector<Ort::Value>> OnlineNeMoCtcModel::UnStackStates(
    std::vector<Ort::Value> states) const {
  std::vector<std::vector<Ort::Value>> per_stream =
      impl_->UnStackStates(std::move(states));

  return per_stream;
}
} // namespace sherpa_onnx

View File

@@ -0,0 +1,81 @@
// sherpa-onnx/csrc/online-nemo-ctc-model.h
//
// Copyright (c) 2024 Xiaomi Corporation
#ifndef SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_
#define SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_
#include <memory>
#include <utility>
#include <vector>
#if __ANDROID_API__ >= 9
#include "android/asset_manager.h"
#include "android/asset_manager_jni.h"
#endif
#include "onnxruntime_cxx_api.h" // NOLINT
#include "sherpa-onnx/csrc/online-ctc-model.h"
#include "sherpa-onnx/csrc/online-model-config.h"
namespace sherpa_onnx {
// Streaming (online) NeMo CTC model. It implements the OnlineCtcModel
// interface; the actual work is done by the private Impl class.
class OnlineNeMoCtcModel : public OnlineCtcModel {
public:
// Creates the model from `config`.
explicit OnlineNeMoCtcModel(const OnlineModelConfig &config);
#if __ANDROID_API__ >= 9
// Creates the model from `config`, using the Android asset manager `mgr`.
OnlineNeMoCtcModel(AAssetManager *mgr, const OnlineModelConfig &config);
#endif
~OnlineNeMoCtcModel() override;
// Return the initial states, which is a list of 3 tensors:
// - cache_last_channel
// - cache_last_time
// - cache_last_channel_len
std::vector<Ort::Value> GetInitStates() const override;
// Combine the states of several streams into a single list of states.
// states[i] contains the states of the i-th stream.
std::vector<Ort::Value> StackStates(
std::vector<std::vector<Ort::Value>> states) const override;
// Inverse of StackStates(): ans[i] contains the states of the i-th stream.
std::vector<std::vector<Ort::Value>> UnStackStates(
std::vector<Ort::Value> states) const override;
/**
*
* @param x A 3-D tensor of shape (N, T, C).
*          NOTE(review): this comment previously said "N has to be 1",
*          which conflicts with SupportBatchProcessing() returning true
*          below; confirm which is correct.
* @param states It is from GetInitStates() or returned from this method.
*
* @return Return a list of tensors
* - ans[0] contains log_probs, of shape (N, T, C)
* - ans[1:] contains next_states
*/
std::vector<Ort::Value> Forward(
Ort::Value x, std::vector<Ort::Value> states) const override;
/** Return the vocabulary size of the model
*/
int32_t VocabSize() const override;
/** Return an allocator for allocating memory
*/
OrtAllocator *Allocator() const override;
// The model accepts this number of frames before subsampling as input
int32_t ChunkLength() const override;
// Similar to frame_shift in feature extractor, after processing
// ChunkLength() frames, we advance by ChunkShift() frames
// before we process the next chunk.
int32_t ChunkShift() const override;
// Always returns true for this model.
bool SupportBatchProcessing() const override { return true; }
private:
// Defined in online-nemo-ctc-model.cc.
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace sherpa_onnx
#endif // SHERPA_ONNX_CSRC_ONLINE_NEMO_CTC_MODEL_H_

View File

@@ -21,7 +21,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
}
if (!config.model_config.wenet_ctc.model.empty() ||
!config.model_config.zipformer2_ctc.model.empty()) {
!config.model_config.zipformer2_ctc.model.empty() ||
!config.model_config.nemo_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcImpl>(config);
}
@@ -41,7 +42,8 @@ std::unique_ptr<OnlineRecognizerImpl> OnlineRecognizerImpl::Create(
}
if (!config.model_config.wenet_ctc.model.empty() ||
!config.model_config.zipformer2_ctc.model.empty()) {
!config.model_config.zipformer2_ctc.model.empty() ||
!config.model_config.nemo_ctc.model.empty()) {
return std::make_unique<OnlineRecognizerCtcImpl>(mgr, config);
}