Stack and streaming conformer support (#141)
* added csrc/stack.cc
* stack: added checks
* added copyright info
* passed cpp style checks
* formatted code
* added some support for streaming conformer model support (not verified)
* code lint
* made more progress with streaming conformer support (not working yet)
* passed style check
* changes as suggested by @csukuangfj
* added some debug info
* fixed style check
* use Cat to replace Stack
* removed debug statements

---------

Co-authored-by: Jingzhao Ou (jou2019) <jou2019@cisco.com>
Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
This commit is contained in:
@@ -34,6 +34,7 @@ set(sources
|
||||
offline-transducer-model-config.cc
|
||||
offline-transducer-model.cc
|
||||
offline-transducer-modified-beam-search-decoder.cc
|
||||
online-conformer-transducer-model.cc
|
||||
online-lm.cc
|
||||
online-lm-config.cc
|
||||
online-lstm-transducer-model.cc
|
||||
@@ -52,6 +53,7 @@ set(sources
|
||||
parse-options.cc
|
||||
resample.cc
|
||||
slice.cc
|
||||
stack.cc
|
||||
symbol-table.cc
|
||||
text-utils.cc
|
||||
transpose.cc
|
||||
@@ -241,6 +243,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
|
||||
packed-sequence-test.cc
|
||||
pad-sequence-test.cc
|
||||
slice-test.cc
|
||||
stack-test.cc
|
||||
transpose-test.cc
|
||||
unbind-test.cc
|
||||
)
|
||||
|
||||
279
sherpa-onnx/csrc/online-conformer-transducer-model.cc
Normal file
279
sherpa-onnx/csrc/online-conformer-transducer-model.cc
Normal file
@@ -0,0 +1,279 @@
|
||||
// sherpa-onnx/csrc/online-conformer-transducer-model.cc
|
||||
//
|
||||
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
|
||||
|
||||
#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/cat.h"
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-decoder.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
#include "sherpa-onnx/csrc/text-utils.h"
|
||||
#include "sherpa-onnx/csrc/unbind.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Loads the encoder, decoder, and joiner ONNX models from the file paths
// given in `config` and creates one Ort::Session per model.
OnlineConformerTransducerModel::OnlineConformerTransducerModel(
    const OnlineTransducerModelConfig &config)
    : env_(ORT_LOGGING_LEVEL_WARNING),
      config_(config),
      sess_opts_{},
      allocator_{} {
  sess_opts_.SetIntraOpNumThreads(config.num_threads);
  sess_opts_.SetInterOpNumThreads(config.num_threads);

  // Each model file is read into a temporary buffer; the inner scopes free
  // each buffer as soon as the corresponding session has been created.
  {
    auto buf = ReadFile(config.encoder_filename);
    InitEncoder(buf.data(), buf.size());
  }

  {
    auto buf = ReadFile(config.decoder_filename);
    InitDecoder(buf.data(), buf.size());
  }

  {
    auto buf = ReadFile(config.joiner_filename);
    InitJoiner(buf.data(), buf.size());
  }
}
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
OnlineConformerTransducerModel::OnlineConformerTransducerModel(
|
||||
AAssetManager *mgr, const OnlineTransducerModelConfig &config)
|
||||
: env_(ORT_LOGGING_LEVEL_WARNING),
|
||||
config_(config),
|
||||
sess_opts_{},
|
||||
allocator_{} {
|
||||
sess_opts_.SetIntraOpNumThreads(config.num_threads);
|
||||
sess_opts_.SetInterOpNumThreads(config.num_threads);
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.encoder_filename);
|
||||
InitEncoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.decoder_filename);
|
||||
InitDecoder(buf.data(), buf.size());
|
||||
}
|
||||
|
||||
{
|
||||
auto buf = ReadFile(mgr, config.joiner_filename);
|
||||
InitJoiner(buf.data(), buf.size());
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Creates the encoder session from an in-memory ONNX model, caches the
// session's input/output names, and reads the streaming parameters
// (layers, chunk size, contexts, etc.) from the model metadata.
//
// @param model_data Pointer to the serialized ONNX encoder model.
// @param model_data_length Size of `model_data` in bytes.
void OnlineConformerTransducerModel::InitEncoder(void *model_data,
                                                 size_t model_data_length) {
  encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
                                                 model_data_length, sess_opts_);

  GetInputNames(encoder_sess_.get(), &encoder_input_names_,
                &encoder_input_names_ptr_);

  GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
                 &encoder_output_names_ptr_);

  // get meta data
  Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---encoder---\n";
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }

  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
  SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
  SHERPA_ONNX_READ_META_DATA(T_, "T");
  SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
  SHERPA_ONNX_READ_META_DATA(left_context_, "left_context");
  SHERPA_ONNX_READ_META_DATA(encoder_dim_, "encoder_dim");
  SHERPA_ONNX_READ_META_DATA(pad_length_, "pad_length");
  SHERPA_ONNX_READ_META_DATA(cnn_module_kernel_, "cnn_module_kernel");
}
|
||||
|
||||
// Creates the decoder (prediction network) session from an in-memory ONNX
// model, caches its input/output names, and reads `vocab_size` and
// `context_size` from the model metadata.
//
// @param model_data Pointer to the serialized ONNX decoder model.
// @param model_data_length Size of `model_data` in bytes.
void OnlineConformerTransducerModel::InitDecoder(void *model_data,
                                                 size_t model_data_length) {
  decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
                                                 model_data_length, sess_opts_);

  GetInputNames(decoder_sess_.get(), &decoder_input_names_,
                &decoder_input_names_ptr_);

  GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
                 &decoder_output_names_ptr_);

  // get meta data
  Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---decoder---\n";
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }

  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
  SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
  SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
}
|
||||
|
||||
// Creates the joiner session from an in-memory ONNX model and caches its
// input/output names. The joiner exposes no extra metadata fields, so only
// debug printing uses the metadata here.
//
// @param model_data Pointer to the serialized ONNX joiner model.
// @param model_data_length Size of `model_data` in bytes.
void OnlineConformerTransducerModel::InitJoiner(void *model_data,
                                                size_t model_data_length) {
  joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
                                                model_data_length, sess_opts_);

  GetInputNames(joiner_sess_.get(), &joiner_input_names_,
                &joiner_input_names_ptr_);

  GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
                 &joiner_output_names_ptr_);

  // get meta data
  Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
  if (config_.debug) {
    std::ostringstream os;
    os << "---joiner---\n";
    PrintModelMetadata(os, meta_data);
    SHERPA_ONNX_LOGE("%s", os.str().c_str());
  }
}
|
||||
|
||||
std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates(
|
||||
const std::vector<std::vector<Ort::Value>> &states) const {
|
||||
int32_t batch_size = static_cast<int32_t>(states.size());
|
||||
|
||||
std::vector<const Ort::Value *> attn_vec(batch_size);
|
||||
std::vector<const Ort::Value *> conv_vec(batch_size);
|
||||
|
||||
for (int32_t i = 0; i != batch_size; ++i) {
|
||||
assert(states[i].size() == 2);
|
||||
attn_vec[i] = &states[i][0];
|
||||
conv_vec[i] = &states[i][1];
|
||||
}
|
||||
|
||||
Ort::Value attn = Cat(allocator_, attn_vec, 2);
|
||||
Ort::Value conv = Cat(allocator_, conv_vec, 2);
|
||||
|
||||
std::vector<Ort::Value> ans;
|
||||
ans.reserve(2);
|
||||
ans.push_back(std::move(attn));
|
||||
ans.push_back(std::move(conv));
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
// Inverse of StackStates(): splits batched states into per-stream states.
//
// `states` holds two tensors (attention cache, convolution cache) whose
// axis 2 is the batch axis. Each is unbound along that axis and the pieces
// are regrouped so that ans[i] contains the two state tensors of stream i.
std::vector<std::vector<Ort::Value>>
OnlineConformerTransducerModel::UnStackStates(
    const std::vector<Ort::Value> &states) const {
  // The batch size is the extent of axis 2 of the attention cache.
  const int32_t batch_size =
      states[0].GetTensorTypeAndShapeInfo().GetShape()[2];
  assert(states.size() == 2);

  std::vector<std::vector<Ort::Value>> ans(batch_size);

  std::vector<Ort::Value> attn_vec = Unbind(allocator_, &states[0], 2);
  std::vector<Ort::Value> conv_vec = Unbind(allocator_, &states[1], 2);

  assert(attn_vec.size() == batch_size);
  assert(conv_vec.size() == batch_size);

  for (int32_t i = 0; i != batch_size; ++i) {
    ans[i].push_back(std::move(attn_vec[i]));
    ans[i].push_back(std::move(conv_vec[i]));
  }

  return ans;
}
|
||||
|
||||
std::vector<Ort::Value> OnlineConformerTransducerModel::GetEncoderInitStates() {
|
||||
// Please see
|
||||
// https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203
|
||||
// for details
|
||||
constexpr int32_t kBatchSize = 1;
|
||||
std::array<int64_t, 4> h_shape{
|
||||
num_encoder_layers_, left_context_, kBatchSize, encoder_dim_};
|
||||
Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
|
||||
h_shape.size());
|
||||
|
||||
Fill<float>(&h, 0);
|
||||
|
||||
std::array<int64_t, 4> c_shape{num_encoder_layers_, cnn_module_kernel_ - 1,
|
||||
kBatchSize, encoder_dim_};
|
||||
|
||||
Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
|
||||
c_shape.size());
|
||||
|
||||
Fill<float>(&c, 0);
|
||||
|
||||
std::vector<Ort::Value> states;
|
||||
|
||||
states.reserve(2);
|
||||
states.push_back(std::move(h));
|
||||
states.push_back(std::move(c));
|
||||
|
||||
return states;
|
||||
}
|
||||
|
||||
// Runs the streaming encoder on one chunk.
//
// @param features Feature tensor; consumed by this call.
// @param states The two state tensors (attention cache, convolution cache)
//               from the previous chunk; consumed by this call.
// @param processed_frames 1-D int64 tensor with the number of frames already
//                         processed for each stream.
// @return A pair of (encoder_out, next_states); session outputs 1 and 2 are
//         the updated attention/convolution caches.
std::pair<Ort::Value, std::vector<Ort::Value>>
OnlineConformerTransducerModel::RunEncoder(Ort::Value features,
                                           std::vector<Ort::Value> states,
                                           Ort::Value processed_frames) {
  // The order of the inputs must match the exported encoder model:
  // (features, attn_cache, conv_cache, processed_frames).
  std::array<Ort::Value, 4> encoder_inputs = {
      std::move(features),
      std::move(states[0]),
      std::move(states[1]),
      std::move(processed_frames)};

  auto encoder_out = encoder_sess_->Run(
      {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
      encoder_inputs.size(), encoder_output_names_ptr_.data(),
      encoder_output_names_ptr_.size());

  std::vector<Ort::Value> next_states;
  next_states.reserve(2);
  next_states.push_back(std::move(encoder_out[1]));
  next_states.push_back(std::move(encoder_out[2]));

  return {std::move(encoder_out[0]), std::move(next_states)};
}
|
||||
|
||||
// Runs the decoder (prediction) network on `decoder_input` and returns its
// single output tensor.
Ort::Value OnlineConformerTransducerModel::RunDecoder(
    Ort::Value decoder_input) {
  constexpr size_t kNumInputs = 1;

  auto outputs = decoder_sess_->Run({}, decoder_input_names_ptr_.data(),
                                    &decoder_input, kNumInputs,
                                    decoder_output_names_ptr_.data(),
                                    decoder_output_names_ptr_.size());

  return std::move(outputs[0]);
}
|
||||
|
||||
// Runs the joiner network on an (encoder_out, decoder_out) pair and returns
// the resulting logits.
Ort::Value OnlineConformerTransducerModel::RunJoiner(Ort::Value encoder_out,
                                                     Ort::Value decoder_out) {
  std::array<Ort::Value, 2> inputs = {std::move(encoder_out),
                                      std::move(decoder_out)};

  auto outputs = joiner_sess_->Run(
      {}, joiner_input_names_ptr_.data(), inputs.data(), inputs.size(),
      joiner_output_names_ptr_.data(), joiner_output_names_ptr_.size());

  return std::move(outputs[0]);
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
108
sherpa-onnx/csrc/online-conformer-transducer-model.h
Normal file
108
sherpa-onnx/csrc/online-conformer-transducer-model.h
Normal file
@@ -0,0 +1,108 @@
|
||||
// sherpa-onnx/csrc/online-conformer-transducer-model.h
|
||||
//
|
||||
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
|
||||
|
||||
#ifndef SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
|
||||
#define SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#if __ANDROID_API__ >= 9
|
||||
#include "android/asset_manager.h"
|
||||
#include "android/asset_manager_jni.h"
|
||||
#endif
|
||||
|
||||
#include "onnxruntime_cxx_api.h" // NOLINT
|
||||
#include "sherpa-onnx/csrc/online-transducer-model-config.h"
|
||||
#include "sherpa-onnx/csrc/online-transducer-model.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Streaming conformer transducer (encoder + decoder + joiner) backed by
// three onnxruntime sessions. The per-stream encoder state is a pair of
// tensors: the attention cache and the convolution-module cache.
class OnlineConformerTransducerModel : public OnlineTransducerModel {
 public:
  // Loads the three models from the file paths in `config`.
  explicit OnlineConformerTransducerModel(
      const OnlineTransducerModelConfig &config);

#if __ANDROID_API__ >= 9
  // Android variant: loads the models through the APK's asset manager.
  OnlineConformerTransducerModel(AAssetManager *mgr,
                                 const OnlineTransducerModelConfig &config);
#endif

  // Concatenates per-stream states along the batch axis (axis 2).
  std::vector<Ort::Value> StackStates(
      const std::vector<std::vector<Ort::Value>> &states) const override;

  // Inverse of StackStates(): splits batched states back per stream.
  std::vector<std::vector<Ort::Value>> UnStackStates(
      const std::vector<Ort::Value> &states) const override;

  // Returns zero-initialized attention/convolution caches for one stream.
  std::vector<Ort::Value> GetEncoderInitStates() override;

  // Runs the encoder on one chunk; returns (encoder_out, next_states).
  std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
      Ort::Value features, std::vector<Ort::Value> states,
      Ort::Value processed_frames) override;

  // Runs the prediction network; returns its output tensor.
  Ort::Value RunDecoder(Ort::Value decoder_input) override;

  // Runs the joiner network; returns the logits tensor.
  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;

  int32_t ContextSize() const override { return context_size_; }

  int32_t ChunkSize() const override { return T_; }

  int32_t ChunkShift() const override { return decode_chunk_len_; }

  int32_t VocabSize() const override { return vocab_size_; }
  OrtAllocator *Allocator() override { return allocator_; }

 private:
  void InitEncoder(void *model_data, size_t model_data_length);
  void InitDecoder(void *model_data, size_t model_data_length);
  void InitJoiner(void *model_data, size_t model_data_length);

 private:
  Ort::Env env_;
  Ort::SessionOptions sess_opts_;
  Ort::AllocatorWithDefaultOptions allocator_;

  std::unique_ptr<Ort::Session> encoder_sess_;
  std::unique_ptr<Ort::Session> decoder_sess_;
  std::unique_ptr<Ort::Session> joiner_sess_;

  // Input/output names per session; the *_ptr_ vectors hold the raw
  // C-string pointers that Ort::Session::Run() expects.
  std::vector<std::string> encoder_input_names_;
  std::vector<const char *> encoder_input_names_ptr_;

  std::vector<std::string> encoder_output_names_;
  std::vector<const char *> encoder_output_names_ptr_;

  std::vector<std::string> decoder_input_names_;
  std::vector<const char *> decoder_input_names_ptr_;

  std::vector<std::string> decoder_output_names_;
  std::vector<const char *> decoder_output_names_ptr_;

  std::vector<std::string> joiner_input_names_;
  std::vector<const char *> joiner_input_names_ptr_;

  std::vector<std::string> joiner_output_names_;
  std::vector<const char *> joiner_output_names_ptr_;

  OnlineTransducerModelConfig config_;

  // The fields below are read from the model metadata in InitEncoder() /
  // InitDecoder().
  int32_t num_encoder_layers_ = 0;
  int32_t T_ = 0;                 // returned by ChunkSize()
  int32_t decode_chunk_len_ = 0;  // returned by ChunkShift()
  int32_t cnn_module_kernel_ = 0;
  int32_t context_size_ = 0;
  int32_t left_context_ = 0;
  // TODO(jingzhaoou): to retrieve from model metadata
  int32_t right_context_ = 4;
  int32_t encoder_dim_ = 0;
  int32_t pad_length_ = 0;
  int32_t vocab_size_ = 0;
};
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
|
||||
#endif // SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
|
||||
@@ -227,7 +227,8 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
|
||||
std::vector<Ort::Value> states) {
|
||||
std::vector<Ort::Value> states,
|
||||
Ort::Value /* processed_frames */) {
|
||||
std::array<Ort::Value, 3> encoder_inputs = {
|
||||
std::move(features), std::move(states[0]), std::move(states[1])};
|
||||
|
||||
|
||||
@@ -38,7 +38,8 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
|
||||
std::vector<Ort::Value> GetEncoderInitStates() override;
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||
Ort::Value features, std::vector<Ort::Value> states) override;
|
||||
Ort::Value features, std::vector<Ort::Value> states,
|
||||
Ort::Value processed_frames) override;
|
||||
|
||||
Ort::Value RunDecoder(Ort::Value decoder_input) override;
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
@@ -187,11 +188,14 @@ class OnlineRecognizer::Impl {
|
||||
std::vector<OnlineTransducerDecoderResult> results(n);
|
||||
std::vector<float> features_vec(n * chunk_size * feature_dim);
|
||||
std::vector<std::vector<Ort::Value>> states_vec(n);
|
||||
std::vector<int64_t> all_processed_frames(n);
|
||||
|
||||
for (int32_t i = 0; i != n; ++i) {
|
||||
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
|
||||
std::vector<float> features =
|
||||
ss[i]->GetFrames(ss[i]->GetNumProcessedFrames(), chunk_size);
|
||||
ss[i]->GetFrames(num_processed_frames, chunk_size);
|
||||
|
||||
// Question: should num_processed_frames include chunk_shift?
|
||||
ss[i]->GetNumProcessedFrames() += chunk_shift;
|
||||
|
||||
std::copy(features.begin(), features.end(),
|
||||
@@ -199,6 +203,7 @@ class OnlineRecognizer::Impl {
|
||||
|
||||
results[i] = std::move(ss[i]->GetResult());
|
||||
states_vec[i] = std::move(ss[i]->GetStates());
|
||||
all_processed_frames[i] = num_processed_frames;
|
||||
}
|
||||
|
||||
auto memory_info =
|
||||
@@ -210,9 +215,20 @@ class OnlineRecognizer::Impl {
|
||||
features_vec.size(), x_shape.data(),
|
||||
x_shape.size());
|
||||
|
||||
std::array<int64_t, 1> processed_frames_shape{
|
||||
static_cast<int64_t>(all_processed_frames.size())};
|
||||
|
||||
Ort::Value processed_frames = Ort::Value::CreateTensor(
|
||||
memory_info,
|
||||
all_processed_frames.data(),
|
||||
all_processed_frames.size(),
|
||||
processed_frames_shape.data(),
|
||||
processed_frames_shape.size());
|
||||
|
||||
auto states = model_->StackStates(states_vec);
|
||||
|
||||
auto pair = model_->RunEncoder(std::move(x), std::move(states));
|
||||
auto pair = model_->RunEncoder(
|
||||
std::move(x), std::move(states), std::move(processed_frames));
|
||||
|
||||
decoder_->Decode(std::move(pair.first), &results);
|
||||
|
||||
|
||||
@@ -10,11 +10,13 @@
|
||||
#endif
|
||||
|
||||
#include <algorithm>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "sherpa-onnx/csrc/macros.h"
|
||||
#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
@@ -22,6 +24,7 @@
|
||||
namespace {
|
||||
|
||||
enum class ModelType {
|
||||
kConformer,
|
||||
kLstm,
|
||||
kZipformer,
|
||||
kUnkown,
|
||||
@@ -57,7 +60,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
|
||||
return ModelType::kUnkown;
|
||||
}
|
||||
|
||||
if (model_type.get() == std::string("lstm")) {
|
||||
if (model_type.get() == std::string("conformer")) {
|
||||
return ModelType::kConformer;
|
||||
} else if (model_type.get() == std::string("lstm")) {
|
||||
return ModelType::kLstm;
|
||||
} else if (model_type.get() == std::string("zipformer")) {
|
||||
return ModelType::kZipformer;
|
||||
@@ -78,6 +83,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
}
|
||||
|
||||
switch (model_type) {
|
||||
case ModelType::kConformer:
|
||||
return std::make_unique<OnlineConformerTransducerModel>(config);
|
||||
case ModelType::kLstm:
|
||||
return std::make_unique<OnlineLstmTransducerModel>(config);
|
||||
case ModelType::kZipformer:
|
||||
@@ -132,6 +139,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
|
||||
auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug);
|
||||
|
||||
switch (model_type) {
|
||||
case ModelType::kConformer:
|
||||
return std::make_unique<OnlineConformerTransducerModel>(mgr, config);
|
||||
case ModelType::kLstm:
|
||||
return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
|
||||
case ModelType::kZipformer:
|
||||
|
||||
@@ -64,6 +64,7 @@ class OnlineTransducerModel {
|
||||
*
|
||||
* @param features A tensor of shape (N, T, C). It is changed in-place.
|
||||
* @param states Encoder state of the previous chunk. It is changed in-place.
|
||||
* @param processed_frames Processed frames before subsampling. It is a 1-D tensor with data type int64_t.
|
||||
*
|
||||
* @return Return a tuple containing:
|
||||
* - encoder_out, a tensor of shape (N, T', encoder_out_dim)
|
||||
@@ -71,7 +72,8 @@ class OnlineTransducerModel {
|
||||
*/
|
||||
virtual std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||
Ort::Value features,
|
||||
std::vector<Ort::Value> states) = 0; // NOLINT
|
||||
std::vector<Ort::Value> states,
|
||||
Ort::Value processed_frames) = 0; // NOLINT
|
||||
|
||||
/** Run the decoder network.
|
||||
*
|
||||
|
||||
@@ -434,7 +434,8 @@ std::vector<Ort::Value> OnlineZipformerTransducerModel::GetEncoderInitStates() {
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>>
|
||||
OnlineZipformerTransducerModel::RunEncoder(Ort::Value features,
|
||||
std::vector<Ort::Value> states) {
|
||||
std::vector<Ort::Value> states,
|
||||
Ort::Value /* processed_frames */) {
|
||||
std::vector<Ort::Value> encoder_inputs;
|
||||
encoder_inputs.reserve(1 + states.size());
|
||||
|
||||
|
||||
@@ -39,7 +39,8 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel {
|
||||
std::vector<Ort::Value> GetEncoderInitStates() override;
|
||||
|
||||
std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
|
||||
Ort::Value features, std::vector<Ort::Value> states) override;
|
||||
Ort::Value features, std::vector<Ort::Value> states,
|
||||
Ort::Value processed_frames) override;
|
||||
|
||||
Ort::Value RunDecoder(Ort::Value decoder_input) override;
|
||||
|
||||
|
||||
@@ -168,6 +168,26 @@ void Print3D(Ort::Value *v) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
// Prints a 4-D float tensor to stderr, one 2-D slice at a time: every
// index along axis 0 is a "plane", every index along axis 1 a "subplane".
void Print4D(Ort::Value *v) {
  auto dims = v->GetTensorTypeAndShapeInfo().GetShape();
  const float *data = v->GetTensorData<float>();

  for (int32_t i0 = 0; i0 != static_cast<int32_t>(dims[0]); ++i0) {
    fprintf(stderr, "---plane %d---\n", i0);
    for (int32_t i1 = 0; i1 != static_cast<int32_t>(dims[1]); ++i1) {
      fprintf(stderr, "---subplane %d---\n", i1);
      for (int32_t i2 = 0; i2 != static_cast<int32_t>(dims[2]); ++i2) {
        for (int32_t i3 = 0; i3 != static_cast<int32_t>(dims[3]); ++i3) {
          fprintf(stderr, "%.3f ", *data);
          ++data;
        }
        fprintf(stderr, "\n");
      }
      fprintf(stderr, "\n");
    }
  }
  fprintf(stderr, "\n");
}
|
||||
|
||||
std::vector<char> ReadFile(const std::string &filename) {
|
||||
std::ifstream input(filename, std::ios::binary);
|
||||
std::vector<char> buffer(std::istreambuf_iterator<char>(input), {});
|
||||
|
||||
@@ -75,6 +75,9 @@ void Print2D(Ort::Value *v);
|
||||
// Print a 3-D tensor to stderr
|
||||
void Print3D(Ort::Value *v);
|
||||
|
||||
// Print a 4-D tensor to stderr
|
||||
void Print4D(Ort::Value *v);
|
||||
|
||||
template <typename T = float>
|
||||
void Fill(Ort::Value *tensor, T value) {
|
||||
auto n = tensor->GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementCount();
|
||||
|
||||
254
sherpa-onnx/csrc/stack-test.cc
Normal file
254
sherpa-onnx/csrc/stack-test.cc
Normal file
@@ -0,0 +1,254 @@
|
||||
// sherpa-onnx/csrc/stack-test.cc
|
||||
//
|
||||
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
|
||||
|
||||
#include "sherpa-onnx/csrc/stack.h"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Stacks two 1-D tensors along dim 0 and checks that row 0 of the resulting
// 2-D tensor equals `a` and row 1 equals `b`.
TEST(Stack, Test1DTensors) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 1> a_shape{3};
  std::array<int64_t, 1> b_shape{3};

  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
                                                 a_shape.size());

  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
                                                 b_shape.size());
  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  // Fill a with 0..2 and b with 10..12 so the two halves are distinguishable.
  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
    pa[i] = i;
  }
  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
    pb[i] = i + 10;
  }

  Ort::Value ans = Stack(allocator, {&a, &b}, 0);

  Print1D(&a);
  Print1D(&b);
  Print2D(&ans);

  const float *pans = ans.GetTensorData<float>();
  // The stacked tensor stores a's elements first, then b's.
  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
    EXPECT_EQ(pa[i], pans[i]);
  }

  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0]); ++i) {
    EXPECT_EQ(pb[i], pans[i + a_shape[0]]);
  }
}
|
||||
|
||||
// Stacks two 2-D tensors along dim 0; the first plane of the 3-D result is
// `a` and the second is `b`.
TEST(Stack, Test2DTensorsDim0) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 2> a_shape{2, 3};
  std::array<int64_t, 2> b_shape{2, 3};

  Ort::Value a = Ort::Value::CreateTensor<float>(
      allocator, a_shape.data(), a_shape.size());

  Ort::Value b = Ort::Value::CreateTensor<float>(
      allocator, b_shape.data(), b_shape.size());

  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  // Distinct fill values so the two halves are distinguishable.
  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
    pa[i] = i;
  }
  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
    pb[i] = i + 10;
  }

  Ort::Value ans = Stack(allocator, {&a, &b}, 0);

  Print2D(&a);
  Print2D(&b);
  Print3D(&ans);

  const float *pans = ans.GetTensorData<float>();
  // Plane 0 of ans == a, plane 1 == b (contiguous row-major layout).
  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
    EXPECT_EQ(pa[i], pans[i]);
  }
  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
    EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1]]);
  }
}
|
||||
|
||||
// Stacks two 2-D tensors along dim 1; in the result, each row of `a` is
// followed by the corresponding row of `b`.
TEST(Stack, Test2DTensorsDim1) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 2> a_shape{4, 3};
  std::array<int64_t, 2> b_shape{4, 3};

  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
                                                 a_shape.size());

  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
                                                 b_shape.size());

  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  // Distinct fill values so the interleaved rows are distinguishable.
  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
    pa[i] = i;
  }
  for (int32_t i = 0; i != static_cast<int32_t>(b_shape[0] * b_shape[1]); ++i) {
    pb[i] = i + 10;
  }

  Ort::Value ans = Stack(allocator, {&a, &b}, 1);

  Print2D(&a);
  Print2D(&b);
  Print3D(&ans);

  const float *pans = ans.GetTensorData<float>();

  // Walk the output: row r of a, then row r of b, for every r.
  for (int32_t r = 0; r != static_cast<int32_t>(a_shape[0]); ++r) {
    for (int32_t i = 0; i != static_cast<int32_t>(a_shape[1]);
         ++i, ++pa, ++pans) {
      EXPECT_EQ(*pa, *pans);
    }

    for (int32_t i = 0; i != static_cast<int32_t>(b_shape[1]);
         ++i, ++pb, ++pans) {
      EXPECT_EQ(*pb, *pans);
    }
  }
}
|
||||
|
||||
// Stacks two 3-D tensors along dim 0; the first half of the 4-D result is
// `a` and the second half is `b`.
TEST(Stack, Test3DTensorsDim0) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 3> a_shape{2, 3, 2};
  std::array<int64_t, 3> b_shape{2, 3, 2};

  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
                                                 a_shape.size());

  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
                                                 b_shape.size());

  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  // Distinct fill values so the two halves are distinguishable.
  for (int32_t i = 0;
       i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
    pa[i] = i;
  }
  for (int32_t i = 0;
       i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
    pb[i] = i + 10;
  }

  Ort::Value ans = Stack(allocator, {&a, &b}, 0);

  const float *pans = ans.GetTensorData<float>();
  // a's elements come first in the contiguous output, then b's.
  for (int32_t i = 0;
       i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
    EXPECT_EQ(pa[i], pans[i]);
  }
  for (int32_t i = 0;
       i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
    EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1] * a_shape[2]]);
  }

  Print3D(&a);
  Print3D(&b);
  Print4D(&ans);
}
|
||||
|
||||
// Stacks two 3-D tensors along dim 1; in the result, each plane of `a` is
// followed by the corresponding plane of `b`.
TEST(Stack, Test3DTensorsDim1) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 3> a_shape{2, 2, 3};
  std::array<int64_t, 3> b_shape{2, 2, 3};

  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
                                                 a_shape.size());

  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
                                                 b_shape.size());

  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  // Distinct fill values so the interleaved planes are distinguishable.
  for (int32_t i = 0;
       i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
    pa[i] = i;
  }
  for (int32_t i = 0;
       i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
    pb[i] = i + 10;
  }

  Ort::Value ans = Stack(allocator, {&a, &b}, 1);

  const float *pans = ans.GetTensorData<float>();

  // Walk the output: plane i of a, then plane i of b, for every i.
  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0]); ++i) {
    for (int32_t k = 0; k != static_cast<int32_t>(a_shape[1] * a_shape[2]);
         ++k, ++pa, ++pans) {
      EXPECT_EQ(*pa, *pans);
    }

    for (int32_t k = 0; k != static_cast<int32_t>(b_shape[1] * b_shape[2]);
         ++k, ++pb, ++pans) {
      EXPECT_EQ(*pb, *pans);
    }
  }

  Print3D(&a);
  Print3D(&b);
  Print4D(&ans);
}
|
||||
|
||||
// Stacks two 3-D tensors along dim 2 (the innermost axis); in the result,
// each row of `a` is followed by the corresponding row of `b`.
TEST(Stack, Test3DTensorsDim2) {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 3> a_shape{2, 3, 4};
  std::array<int64_t, 3> b_shape{2, 3, 4};

  Ort::Value a = Ort::Value::CreateTensor<float>(allocator, a_shape.data(),
                                                 a_shape.size());

  Ort::Value b = Ort::Value::CreateTensor<float>(allocator, b_shape.data(),
                                                 b_shape.size());

  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  // Distinct fill values so the interleaved rows are distinguishable.
  for (int32_t i = 0;
       i != static_cast<int32_t>(a_shape[0] * a_shape[1] * a_shape[2]); ++i) {
    pa[i] = i;
  }
  for (int32_t i = 0;
       i != static_cast<int32_t>(b_shape[0] * b_shape[1] * b_shape[2]); ++i) {
    pb[i] = i + 10;
  }

  Ort::Value ans = Stack(allocator, {&a, &b}, 2);

  const float *pans = ans.GetTensorData<float>();

  // Walk the output: innermost row of a, then the matching row of b.
  for (int32_t i = 0; i != static_cast<int32_t>(a_shape[0] * a_shape[1]); ++i) {
    for (int32_t k = 0; k != static_cast<int32_t>(a_shape[2]);
         ++k, ++pa, ++pans) {
      EXPECT_EQ(*pa, *pans);
    }

    for (int32_t k = 0; k != static_cast<int32_t>(b_shape[2]);
         ++k, ++pb, ++pans) {
      EXPECT_EQ(*pb, *pans);
    }
  }

  Print3D(&a);
  Print3D(&b);
  Print4D(&ans);
}
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
101
sherpa-onnx/csrc/stack.cc
Normal file
101
sherpa-onnx/csrc/stack.cc
Normal file
@@ -0,0 +1,101 @@
|
||||
// sherpa-onnx/csrc/stack.cc
|
||||
//
|
||||
// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com)
|
||||
|
||||
#include "sherpa-onnx/csrc/stack.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
|
||||
#include "sherpa-onnx/csrc/onnx-utils.h"
|
||||
|
||||
namespace sherpa_onnx {
|
||||
|
||||
// Return true if the two shape vectors are element-wise identical.
//
// @param a First shape.
// @param b Second shape.
static bool Compare(const std::vector<int64_t> &a,
                    const std::vector<int64_t> &b) {
  // std::vector::operator== already compares sizes first and then the
  // elements, which is exactly what the hand-written loop did.
  return a == b;
}
|
||||
|
||||
// Print a shape to stderr as space-separated integers followed by a
// newline. Used only for error reporting.
static void PrintShape(const std::vector<int64_t> &shape) {
  for (size_t i = 0; i != shape.size(); ++i) {
    fprintf(stderr, "%d ", static_cast<int32_t>(shape[i]));
  }
  fprintf(stderr, "\n");
}
|
||||
|
||||
template <typename T /*=float*/>
|
||||
Ort::Value Stack(OrtAllocator *allocator,
|
||||
const std::vector<const Ort::Value *> &values, int32_t dim) {
|
||||
std::vector<int64_t> v0_shape =
|
||||
values[0]->GetTensorTypeAndShapeInfo().GetShape();
|
||||
|
||||
for (int32_t i = 1; i != static_cast<int32_t>(values.size()); ++i) {
|
||||
auto s = values[i]->GetTensorTypeAndShapeInfo().GetShape();
|
||||
bool ret = Compare(v0_shape, s);
|
||||
if (!ret) {
|
||||
fprintf(stderr, "Incorrect shape in Stack !\n");
|
||||
|
||||
fprintf(stderr, "Shape for tensor 0: ");
|
||||
PrintShape(v0_shape);
|
||||
|
||||
fprintf(stderr, "Shape for tensor %d: ", i);
|
||||
PrintShape(s);
|
||||
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<int64_t> ans_shape;
|
||||
ans_shape.reserve(v0_shape.size() + 1);
|
||||
ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
|
||||
ans_shape.push_back(values.size());
|
||||
ans_shape.insert(
|
||||
ans_shape.end(),
|
||||
v0_shape.data() + dim,
|
||||
v0_shape.data() + v0_shape.size());
|
||||
|
||||
auto leading_size = static_cast<int32_t>(std::accumulate(
|
||||
v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));
|
||||
|
||||
auto trailing_size = static_cast<int32_t>(
|
||||
std::accumulate(v0_shape.begin() + dim,
|
||||
v0_shape.end(), 1,
|
||||
std::multiplies<int64_t>()));
|
||||
|
||||
Ort::Value ans = Ort::Value::CreateTensor<T>(
|
||||
allocator, ans_shape.data(), ans_shape.size());
|
||||
T *dst = ans.GetTensorMutableData<T>();
|
||||
|
||||
for (int32_t i = 0; i != leading_size; ++i) {
|
||||
for (int32_t n = 0; n != static_cast<int32_t>(values.size()); ++n) {
|
||||
const T *src = values[n]->GetTensorData<T>();
|
||||
src += i * trailing_size;
|
||||
|
||||
std::copy(src, src + trailing_size, dst);
|
||||
dst += trailing_size;
|
||||
}
|
||||
}
|
||||
|
||||
return ans;
|
||||
}
|
||||
|
||||
// Explicit instantiations for the element types used by the transducer
// models (float features/states and int64 token ids).
template Ort::Value Stack<float>(
    OrtAllocator *allocator,
    const std::vector<const Ort::Value *> &values,
    int32_t dim);

template Ort::Value Stack<int64_t>(
    OrtAllocator *allocator,
    const std::vector<const Ort::Value *> &values,
    int32_t dim);
|
||||
|
||||
} // namespace sherpa_onnx
|
||||
29
sherpa-onnx/csrc/stack.h
Normal file
29
sherpa-onnx/csrc/stack.h
Normal file
@@ -0,0 +1,29 @@
|
||||
// sherpa-onnx/csrc/stack.h
//
// Copyright (c)  2023  Jingzhao Ou (jingzhao.ou@gmail.com)

#ifndef SHERPA_ONNX_CSRC_STACK_H_
#define SHERPA_ONNX_CSRC_STACK_H_

#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

namespace sherpa_onnx {

/** Stack a list of tensors along the given dim.
 *
 * @param allocator Allocator to allocate space for the returned tensor
 * @param values Pointer to a list of tensors. All tensors must have
 *               exactly the same shape.
 * @param dim The dim along which to stack the input tensors. The
 *            returned tensor gains one extra dimension of size
 *            values.size() inserted at this position.
 *
 * @return Return the stacked tensor
 */
template <typename T = float>
Ort::Value Stack(OrtAllocator *allocator,
                 const std::vector<const Ort::Value *> &values, int32_t dim);

}  // namespace sherpa_onnx

#endif  // SHERPA_ONNX_CSRC_STACK_H_
|
||||
Reference in New Issue
Block a user