diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index 3216af65..bef262f9 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -34,6 +34,7 @@ set(sources
   offline-transducer-model-config.cc
   offline-transducer-model.cc
   offline-transducer-modified-beam-search-decoder.cc
+  online-conformer-transducer-model.cc
   online-lm.cc
   online-lm-config.cc
   online-lstm-transducer-model.cc
@@ -52,6 +53,7 @@ set(sources
   parse-options.cc
   resample.cc
   slice.cc
+  stack.cc
   symbol-table.cc
   text-utils.cc
   transpose.cc
@@ -241,6 +243,7 @@ if(SHERPA_ONNX_ENABLE_TESTS)
     packed-sequence-test.cc
     pad-sequence-test.cc
     slice-test.cc
+    stack-test.cc
     transpose-test.cc
     unbind-test.cc
   )
diff --git a/sherpa-onnx/csrc/online-conformer-transducer-model.cc b/sherpa-onnx/csrc/online-conformer-transducer-model.cc
new file mode 100644
index 00000000..8584f0ec
--- /dev/null
+++ b/sherpa-onnx/csrc/online-conformer-transducer-model.cc
@@ -0,0 +1,279 @@
+// sherpa-onnx/csrc/online-conformer-transducer-model.cc
+//
+// Copyright (c)  2023  Jingzhao Ou (jingzhao.ou@gmail.com)
+
+#include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
+
+#include <assert.h>
+
+#include <algorithm>
+#include <array>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/cat.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+#include "sherpa-onnx/csrc/unbind.h"
+
+namespace sherpa_onnx {
+
+OnlineConformerTransducerModel::OnlineConformerTransducerModel(
+    const OnlineTransducerModelConfig &config)
+    : env_(ORT_LOGGING_LEVEL_WARNING),
+      config_(config),
+      sess_opts_{},
+      allocator_{} {
+  sess_opts_.SetIntraOpNumThreads(config.num_threads);
+  sess_opts_.SetInterOpNumThreads(config.num_threads);
+
+  {
+    auto buf = ReadFile(config.encoder_filename);
+    InitEncoder(buf.data(), buf.size());
+  }
+
+  {
+    auto buf = ReadFile(config.decoder_filename);
+    InitDecoder(buf.data(), buf.size());
+  }
+
+  {
+    auto buf = ReadFile(config.joiner_filename);
+    InitJoiner(buf.data(), buf.size());
+  }
+}
+
+#if __ANDROID_API__ >= 9
+OnlineConformerTransducerModel::OnlineConformerTransducerModel(
+    AAssetManager *mgr, const OnlineTransducerModelConfig &config)
+    : env_(ORT_LOGGING_LEVEL_WARNING),
+      config_(config),
+      sess_opts_{},
+      allocator_{} {
+  sess_opts_.SetIntraOpNumThreads(config.num_threads);
+  sess_opts_.SetInterOpNumThreads(config.num_threads);
+
+  {
+    auto buf = ReadFile(mgr, config.encoder_filename);
+    InitEncoder(buf.data(), buf.size());
+  }
+
+  {
+    auto buf = ReadFile(mgr, config.decoder_filename);
+    InitDecoder(buf.data(), buf.size());
+  }
+
+  {
+    auto buf = ReadFile(mgr, config.joiner_filename);
+    InitJoiner(buf.data(), buf.size());
+  }
+}
+#endif
+
+void OnlineConformerTransducerModel::InitEncoder(void *model_data,
+                                                 size_t model_data_length) {
+  encoder_sess_ = std::make_unique<Ort::Session>(
+      env_, model_data, model_data_length, sess_opts_);
+
+  GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+                &encoder_input_names_ptr_);
+
+  GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+                 &encoder_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---encoder---\n";
+    PrintModelMetadata(os, meta_data);
+    SHERPA_ONNX_LOGE("%s", os.str().c_str());
+  }
+
+  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+  SHERPA_ONNX_READ_META_DATA(num_encoder_layers_, "num_encoder_layers");
+  SHERPA_ONNX_READ_META_DATA(T_, "T");
+  SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
+  SHERPA_ONNX_READ_META_DATA(left_context_, "left_context");
+  SHERPA_ONNX_READ_META_DATA(encoder_dim_, "encoder_dim");
+  SHERPA_ONNX_READ_META_DATA(pad_length_, "pad_length");
+  SHERPA_ONNX_READ_META_DATA(cnn_module_kernel_, "cnn_module_kernel");
+}
+
+void OnlineConformerTransducerModel::InitDecoder(void *model_data,
+                                                 size_t model_data_length) {
+  decoder_sess_ = std::make_unique<Ort::Session>(
+      env_, model_data, model_data_length, sess_opts_);
+
+  GetInputNames(decoder_sess_.get(), &decoder_input_names_,
+                &decoder_input_names_ptr_);
+
+  GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
+                 &decoder_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---decoder---\n";
+    PrintModelMetadata(os, meta_data);
+    SHERPA_ONNX_LOGE("%s", os.str().c_str());
+  }
+
+  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+  SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+  SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
+}
+
+void OnlineConformerTransducerModel::InitJoiner(void *model_data,
+                                                size_t model_data_length) {
+  joiner_sess_ = std::make_unique<Ort::Session>(
+      env_, model_data, model_data_length, sess_opts_);
+
+  GetInputNames(joiner_sess_.get(), &joiner_input_names_,
+                &joiner_input_names_ptr_);
+
+  GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
+                 &joiner_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---joiner---\n";
+    PrintModelMetadata(os, meta_data);
+    SHERPA_ONNX_LOGE("%s", os.str().c_str());
+  }
+}
+
+std::vector<Ort::Value> OnlineConformerTransducerModel::StackStates(
+    const std::vector<std::vector<Ort::Value>> &states) const {
+  int32_t batch_size = static_cast<int32_t>(states.size());
+
+  std::vector<const Ort::Value *> attn_vec(batch_size);
+  std::vector<const Ort::Value *> conv_vec(batch_size);
+
+  for (int32_t i = 0; i != batch_size; ++i) {
+    assert(states[i].size() == 2);
+    attn_vec[i] = &states[i][0];
+    conv_vec[i] = &states[i][1];
+  }
+
+  Ort::Value attn = Cat(allocator_, attn_vec, 2);
+  Ort::Value conv = Cat(allocator_, conv_vec, 2);
+
+  std::vector<Ort::Value> ans;
+  ans.reserve(2);
+  ans.push_back(std::move(attn));
+  ans.push_back(std::move(conv));
+
+  return ans;
+}
+
+std::vector<std::vector<Ort::Value>>
+OnlineConformerTransducerModel::UnStackStates(
+    const std::vector<Ort::Value> &states) const {
+  const int32_t batch_size =
+      states[0].GetTensorTypeAndShapeInfo().GetShape()[2];
+  assert(states.size() == 2);
+
+  std::vector<std::vector<Ort::Value>> ans(batch_size);
+
+  std::vector<Ort::Value> attn_vec = Unbind(allocator_, &states[0], 2);
+  std::vector<Ort::Value> conv_vec = Unbind(allocator_, &states[1], 2);
+
+  assert(attn_vec.size() == batch_size);
+  assert(conv_vec.size() == batch_size);
+
+  for (int32_t i = 0; i != batch_size; ++i) {
+    ans[i].push_back(std::move(attn_vec[i]));
+    ans[i].push_back(std::move(conv_vec[i]));
+  }
+
+  return ans;
+}
+
+std::vector<Ort::Value> OnlineConformerTransducerModel::GetEncoderInitStates() {
+  // Please see
+  // https://github.com/k2-fsa/icefall/blob/86b0db6eb9c84d9bc90a71d92774fe2a7f73e6ab/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py#L203
+  // for details
+  constexpr int32_t kBatchSize = 1;
+  std::array<int64_t, 4> h_shape{num_encoder_layers_, left_context_,
+                                 kBatchSize, encoder_dim_};
+  Ort::Value h = Ort::Value::CreateTensor<float>(allocator_, h_shape.data(),
+                                                 h_shape.size());
+
+  Fill(&h, 0);
+
+  std::array<int64_t, 4> c_shape{num_encoder_layers_, cnn_module_kernel_ - 1,
+                                 kBatchSize, encoder_dim_};
+
+  Ort::Value c = Ort::Value::CreateTensor<float>(allocator_, c_shape.data(),
+                                                 c_shape.size());
+
+  Fill(&c, 0);
+
+  std::vector<Ort::Value> states;
+
+  states.reserve(2);
+  states.push_back(std::move(h));
+  states.push_back(std::move(c));
+
+  return states;
+}
+
+std::pair<Ort::Value, std::vector<Ort::Value>>
+OnlineConformerTransducerModel::RunEncoder(Ort::Value features,
+                                           std::vector<Ort::Value> states,
+                                           Ort::Value processed_frames) {
+  std::array<Ort::Value, 4> encoder_inputs = {
+      std::move(features),
+      std::move(states[0]),
+      std::move(states[1]),
+      std::move(processed_frames)};
+
+  auto encoder_out = encoder_sess_->Run(
+      {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
+      encoder_inputs.size(), encoder_output_names_ptr_.data(),
+      encoder_output_names_ptr_.size());
+
+  std::vector<Ort::Value> next_states;
+  next_states.reserve(2);
+  next_states.push_back(std::move(encoder_out[1]));
+  next_states.push_back(std::move(encoder_out[2]));
+
+  return {std::move(encoder_out[0]), std::move(next_states)};
+}
+
+Ort::Value OnlineConformerTransducerModel::RunDecoder(
+    Ort::Value decoder_input) {
+  auto decoder_out = decoder_sess_->Run(
+      {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
+      decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
+  return std::move(decoder_out[0]);
+}
+
+Ort::Value OnlineConformerTransducerModel::RunJoiner(Ort::Value encoder_out,
+                                                     Ort::Value decoder_out) {
+  std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
+                                            std::move(decoder_out)};
+  auto logit =
+      joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
+                        joiner_input.size(), joiner_output_names_ptr_.data(),
+                        joiner_output_names_ptr_.size());
+
+  return std::move(logit[0]);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-conformer-transducer-model.h b/sherpa-onnx/csrc/online-conformer-transducer-model.h
new file mode 100644
index 00000000..f60ed53c
--- /dev/null
+++ b/sherpa-onnx/csrc/online-conformer-transducer-model.h
@@ -0,0 +1,108 @@
+// sherpa-onnx/csrc/online-conformer-transducer-model.h
+//
+// Copyright (c)  2023  Jingzhao Ou (jingzhao.ou@gmail.com)
+
+#ifndef SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/online-transducer-model-config.h"
+#include "sherpa-onnx/csrc/online-transducer-model.h"
+
+namespace sherpa_onnx {
+
+class OnlineConformerTransducerModel : public OnlineTransducerModel {
+ public:
+  explicit OnlineConformerTransducerModel(
+      const OnlineTransducerModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+  OnlineConformerTransducerModel(AAssetManager *mgr,
+                                 const OnlineTransducerModelConfig &config);
+#endif
+
+  std::vector<Ort::Value> StackStates(
+      const std::vector<std::vector<Ort::Value>> &states) const override;
+
+  std::vector<std::vector<Ort::Value>> UnStackStates(
+      const std::vector<Ort::Value> &states) const override;
+
+  std::vector<Ort::Value> GetEncoderInitStates() override;
+
+  std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
+      Ort::Value features, std::vector<Ort::Value> states,
+      Ort::Value processed_frames) override;
+
+  Ort::Value RunDecoder(Ort::Value decoder_input) override;
+
+  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
+
+  int32_t ContextSize() const override { return context_size_; }
+
+  int32_t ChunkSize() const override { return T_; }
+
+  int32_t ChunkShift() const override { return decode_chunk_len_; }
+
+  int32_t VocabSize() const override { return vocab_size_; }
+  OrtAllocator *Allocator() override { return allocator_; }
+
+ private:
+  void InitEncoder(void *model_data, size_t model_data_length);
+  void InitDecoder(void *model_data, size_t model_data_length);
+  void InitJoiner(void *model_data, size_t model_data_length);
+
+ private:
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> encoder_sess_;
+  std::unique_ptr<Ort::Session> decoder_sess_;
+  std::unique_ptr<Ort::Session> joiner_sess_;
+
+  std::vector<std::string> encoder_input_names_;
+  std::vector<const char *> encoder_input_names_ptr_;
+
+  std::vector<std::string> encoder_output_names_;
+  std::vector<const char *> encoder_output_names_ptr_;
+
+  std::vector<std::string> decoder_input_names_;
+  std::vector<const char *> decoder_input_names_ptr_;
+
+  std::vector<std::string> decoder_output_names_;
+  std::vector<const char *> decoder_output_names_ptr_;
+
+  std::vector<std::string> joiner_input_names_;
+  std::vector<const char *> joiner_input_names_ptr_;
+
+  std::vector<std::string> joiner_output_names_;
+  std::vector<const char *> joiner_output_names_ptr_;
+
+  OnlineTransducerModelConfig config_;
+
+  int32_t num_encoder_layers_ = 0;
+  int32_t T_ = 0;
+  int32_t decode_chunk_len_ = 0;
+  int32_t cnn_module_kernel_ = 0;
+  int32_t context_size_ = 0;
+  int32_t left_context_ = 0;
+  // TODO(jingzhaoou): to retrieve from model metadata
+  int32_t right_context_ = 4;
+  int32_t encoder_dim_ = 0;
+  int32_t pad_length_ = 0;
+  int32_t vocab_size_ = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_CONFORMER_TRANSDUCER_MODEL_H_
diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.cc b/sherpa-onnx/csrc/online-lstm-transducer-model.cc
index 25e300ae..ee40b10c 100644
--- a/sherpa-onnx/csrc/online-lstm-transducer-model.cc
+++ b/sherpa-onnx/csrc/online-lstm-transducer-model.cc
@@ -227,7 +227,8 @@ std::vector<Ort::Value> OnlineLstmTransducerModel::GetEncoderInitStates() {
 
 std::pair<Ort::Value, std::vector<Ort::Value>>
 OnlineLstmTransducerModel::RunEncoder(Ort::Value features,
-                                      std::vector<Ort::Value> states) {
+                                      std::vector<Ort::Value> states,
+                                      Ort::Value /* processed_frames */) {
   std::array<Ort::Value, 3> encoder_inputs = {
       std::move(features), std::move(states[0]), std::move(states[1])};
 
diff --git a/sherpa-onnx/csrc/online-lstm-transducer-model.h b/sherpa-onnx/csrc/online-lstm-transducer-model.h
index 5b6ad282..ab673a19 100644
--- a/sherpa-onnx/csrc/online-lstm-transducer-model.h
+++ b/sherpa-onnx/csrc/online-lstm-transducer-model.h
@@ -38,7 +38,8 @@ class OnlineLstmTransducerModel : public OnlineTransducerModel {
   std::vector<Ort::Value> GetEncoderInitStates() override;
 
   std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
-      Ort::Value features, std::vector<Ort::Value> states) override;
+      Ort::Value features, std::vector<Ort::Value> states,
+      Ort::Value processed_frames) override;
 
   Ort::Value RunDecoder(Ort::Value decoder_input) override;
 
diff --git a/sherpa-onnx/csrc/online-recognizer.cc b/sherpa-onnx/csrc/online-recognizer.cc
index 0cd68653..b1126cd3 100644
--- a/sherpa-onnx/csrc/online-recognizer.cc
+++ b/sherpa-onnx/csrc/online-recognizer.cc
@@ -9,6 +9,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -187,11 +188,14 @@ class OnlineRecognizer::Impl {
     std::vector<OnlineTransducerDecoderResult> results(n);
     std::vector<float> features_vec(n * chunk_size * feature_dim);
     std::vector<std::vector<Ort::Value>> states_vec(n);
+    std::vector<int64_t> all_processed_frames(n);
 
     for (int32_t i = 0; i != n; ++i) {
+      const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
       std::vector<float> features =
-          ss[i]->GetFrames(ss[i]->GetNumProcessedFrames(), chunk_size);
+          
ss[i]->GetFrames(num_processed_frames, chunk_size); + // Question: should num_processed_frames include chunk_shift? ss[i]->GetNumProcessedFrames() += chunk_shift; std::copy(features.begin(), features.end(), @@ -199,6 +203,7 @@ class OnlineRecognizer::Impl { results[i] = std::move(ss[i]->GetResult()); states_vec[i] = std::move(ss[i]->GetStates()); + all_processed_frames[i] = num_processed_frames; } auto memory_info = @@ -210,9 +215,20 @@ class OnlineRecognizer::Impl { features_vec.size(), x_shape.data(), x_shape.size()); + std::array processed_frames_shape{ + static_cast(all_processed_frames.size())}; + + Ort::Value processed_frames = Ort::Value::CreateTensor( + memory_info, + all_processed_frames.data(), + all_processed_frames.size(), + processed_frames_shape.data(), + processed_frames_shape.size()); + auto states = model_->StackStates(states_vec); - auto pair = model_->RunEncoder(std::move(x), std::move(states)); + auto pair = model_->RunEncoder( + std::move(x), std::move(states), std::move(processed_frames)); decoder_->Decode(std::move(pair.first), &results); diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index 89ad630e..75fb3c56 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -10,11 +10,13 @@ #endif #include +#include #include #include #include #include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-conformer-transducer-model.h" #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h" #include "sherpa-onnx/csrc/onnx-utils.h" @@ -22,6 +24,7 @@ namespace { enum class ModelType { + kConformer, kLstm, kZipformer, kUnkown, @@ -57,7 +60,9 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, return ModelType::kUnkown; } - if (model_type.get() == std::string("lstm")) { + if (model_type.get() == std::string("conformer")) { + return ModelType::kConformer; + } else if (model_type.get() == std::string("lstm")) { return ModelType::kLstm; } else if (model_type.get() == std::string("zipformer")) { return ModelType::kZipformer; @@ -78,6 +83,8 @@ std::unique_ptr OnlineTransducerModel::Create( } switch (model_type) { + case ModelType::kConformer: + return std::make_unique(config); case ModelType::kLstm: return std::make_unique(config); case ModelType::kZipformer: @@ -132,6 +139,8 @@ std::unique_ptr OnlineTransducerModel::Create( auto model_type = GetModelType(buffer.data(), buffer.size(), config.debug); switch (model_type) { + case ModelType::kConformer: + return std::make_unique(mgr, config); case ModelType::kLstm: return std::make_unique(mgr, config); case ModelType::kZipformer: diff --git a/sherpa-onnx/csrc/online-transducer-model.h b/sherpa-onnx/csrc/online-transducer-model.h index 3a034516..42539de9 100644 --- a/sherpa-onnx/csrc/online-transducer-model.h +++ b/sherpa-onnx/csrc/online-transducer-model.h @@ -64,6 +64,7 @@ class OnlineTransducerModel { * * @param features A tensor of shape (N, T, C). It is changed in-place. * @param states Encoder state of the previous chunk. It is changed in-place. + * @param processed_frames Processed frames before subsampling. It is a 1-D tensor with data type int64_t. 
* * @return Return a tuple containing: * - encoder_out, a tensor of shape (N, T', encoder_out_dim) @@ -71,7 +72,8 @@ class OnlineTransducerModel { */ virtual std::pair> RunEncoder( Ort::Value features, - std::vector states) = 0; // NOLINT + std::vector states, + Ort::Value processed_frames) = 0; // NOLINT /** Run the decoder network. * diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc index 4265a76f..7af95cc9 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.cc +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.cc @@ -434,7 +434,8 @@ std::vector OnlineZipformerTransducerModel::GetEncoderInitStates() { std::pair> OnlineZipformerTransducerModel::RunEncoder(Ort::Value features, - std::vector states) { + std::vector states, + Ort::Value /* processed_frames */) { std::vector encoder_inputs; encoder_inputs.reserve(1 + states.size()); diff --git a/sherpa-onnx/csrc/online-zipformer-transducer-model.h b/sherpa-onnx/csrc/online-zipformer-transducer-model.h index c2f237a3..b3e1966a 100644 --- a/sherpa-onnx/csrc/online-zipformer-transducer-model.h +++ b/sherpa-onnx/csrc/online-zipformer-transducer-model.h @@ -39,7 +39,8 @@ class OnlineZipformerTransducerModel : public OnlineTransducerModel { std::vector GetEncoderInitStates() override; std::pair> RunEncoder( - Ort::Value features, std::vector states) override; + Ort::Value features, std::vector states, + Ort::Value processed_frames) override; Ort::Value RunDecoder(Ort::Value decoder_input) override; diff --git a/sherpa-onnx/csrc/onnx-utils.cc b/sherpa-onnx/csrc/onnx-utils.cc index 99ca4416..80f01e49 100644 --- a/sherpa-onnx/csrc/onnx-utils.cc +++ b/sherpa-onnx/csrc/onnx-utils.cc @@ -168,6 +168,26 @@ void Print3D(Ort::Value *v) { fprintf(stderr, "\n"); } +void Print4D(Ort::Value *v) { + std::vector shape = v->GetTensorTypeAndShapeInfo().GetShape(); + const float *d = v->GetTensorData(); + + for (int32_t p = 0; p != static_cast(shape[0]); ++p) { + fprintf(stderr, "---plane %d---\n", p); + for (int32_t q = 0; q != static_cast(shape[1]); ++q) { + fprintf(stderr, "---subplane %d---\n", q); + for (int32_t r = 0; r != static_cast(shape[2]); ++r) { + for (int32_t c = 0; c != static_cast(shape[3]); ++c, ++d) { + fprintf(stderr, "%.3f ", *d); + } + fprintf(stderr, "\n"); + } + fprintf(stderr, "\n"); + } + } + fprintf(stderr, "\n"); +} + std::vector ReadFile(const std::string &filename) { std::ifstream input(filename, std::ios::binary); std::vector buffer(std::istreambuf_iterator(input), {}); diff --git a/sherpa-onnx/csrc/onnx-utils.h b/sherpa-onnx/csrc/onnx-utils.h index 34ebc92e..3dc0e0fc 100644 --- a/sherpa-onnx/csrc/onnx-utils.h +++ b/sherpa-onnx/csrc/onnx-utils.h @@ -75,6 +75,9 @@ void Print2D(Ort::Value *v); // Print a 3-D tensor to stderr void Print3D(Ort::Value *v); +// Print a 4-D tensor to stderr +void Print4D(Ort::Value *v); + template void Fill(Ort::Value *tensor, T value) { auto n = tensor->GetTypeInfo().GetTensorTypeAndShapeInfo().GetElementCount(); diff --git a/sherpa-onnx/csrc/stack-test.cc b/sherpa-onnx/csrc/stack-test.cc new file mode 100644 index 00000000..45a8dfaa --- /dev/null +++ b/sherpa-onnx/csrc/stack-test.cc @@ -0,0 +1,254 @@ +// sherpa-onnx/csrc/stack-test.cc +// +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) + +#include "sherpa-onnx/csrc/stack.h" + +#include "gtest/gtest.h" +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +TEST(Stack, Test1DTensors) { + Ort::AllocatorWithDefaultOptions allocator; + + 
std::array a_shape{3}; + std::array b_shape{3}; + + Ort::Value a = Ort::Value::CreateTensor(allocator, a_shape.data(), + a_shape.size()); + + Ort::Value b = Ort::Value::CreateTensor(allocator, b_shape.data(), + b_shape.size()); + float *pa = a.GetTensorMutableData(); + float *pb = b.GetTensorMutableData(); + for (int32_t i = 0; i != static_cast(a_shape[0]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; i != static_cast(b_shape[0]); ++i) { + pb[i] = i + 10; + } + + Ort::Value ans = Stack(allocator, {&a, &b}, 0); + + Print1D(&a); + Print1D(&b); + Print2D(&ans); + + const float *pans = ans.GetTensorData(); + for (int32_t i = 0; i != static_cast(a_shape[0]); ++i) { + EXPECT_EQ(pa[i], pans[i]); + } + + for (int32_t i = 0; i != static_cast(b_shape[0]); ++i) { + EXPECT_EQ(pb[i], pans[i + a_shape[0]]); + } +} + +TEST(Stack, Test2DTensorsDim0) { + Ort::AllocatorWithDefaultOptions allocator; + + std::array a_shape{2, 3}; + std::array b_shape{2, 3}; + + Ort::Value a = Ort::Value::CreateTensor( + allocator, a_shape.data(), a_shape.size()); + + Ort::Value b = Ort::Value::CreateTensor( + allocator, b_shape.data(), b_shape.size()); + + float *pa = a.GetTensorMutableData(); + float *pb = b.GetTensorMutableData(); + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; i != static_cast(b_shape[0] * b_shape[1]); ++i) { + pb[i] = i + 10; + } + + Ort::Value ans = Stack(allocator, {&a, &b}, 0); + + Print2D(&a); + Print2D(&b); + Print3D(&ans); + + const float *pans = ans.GetTensorData(); + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + EXPECT_EQ(pa[i], pans[i]); + } + for (int32_t i = 0; i != static_cast(b_shape[0] * b_shape[1]); ++i) { + EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1]]); + } +} + +TEST(Stack, Test2DTensorsDim1) { + Ort::AllocatorWithDefaultOptions allocator; + + std::array a_shape{4, 3}; + std::array b_shape{4, 3}; + + Ort::Value a = Ort::Value::CreateTensor(allocator, a_shape.data(), + a_shape.size()); + + Ort::Value b = Ort::Value::CreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a.GetTensorMutableData(); + float *pb = b.GetTensorMutableData(); + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; i != static_cast(b_shape[0] * b_shape[1]); ++i) { + pb[i] = i + 10; + } + + Ort::Value ans = Stack(allocator, {&a, &b}, 1); + + Print2D(&a); + Print2D(&b); + Print3D(&ans); + + const float *pans = ans.GetTensorData(); + + for (int32_t r = 0; r != static_cast(a_shape[0]); ++r) { + for (int32_t i = 0; i != static_cast(a_shape[1]); + ++i, ++pa, ++pans) { + EXPECT_EQ(*pa, *pans); + } + + for (int32_t i = 0; i != static_cast(b_shape[1]); + ++i, ++pb, ++pans) { + EXPECT_EQ(*pb, *pans); + } + } +} + +TEST(Stack, Test3DTensorsDim0) { + Ort::AllocatorWithDefaultOptions allocator; + + std::array a_shape{2, 3, 2}; + std::array b_shape{2, 3, 2}; + + Ort::Value a = Ort::Value::CreateTensor(allocator, a_shape.data(), + a_shape.size()); + + Ort::Value b = Ort::Value::CreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a.GetTensorMutableData(); + float *pb = b.GetTensorMutableData(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + pb[i] = i + 10; + } + + Ort::Value ans = Stack(allocator, {&a, &b}, 0); + + const float *pans = ans.GetTensorData(); + for (int32_t i = 0; + i != 
static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + EXPECT_EQ(pa[i], pans[i]); + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + EXPECT_EQ(pb[i], pans[i + a_shape[0] * a_shape[1] * a_shape[2]]); + } + + Print3D(&a); + Print3D(&b); + Print4D(&ans); +} + +TEST(Stack, Test3DTensorsDim1) { + Ort::AllocatorWithDefaultOptions allocator; + + std::array a_shape{2, 2, 3}; + std::array b_shape{2, 2, 3}; + + Ort::Value a = Ort::Value::CreateTensor(allocator, a_shape.data(), + a_shape.size()); + + Ort::Value b = Ort::Value::CreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a.GetTensorMutableData(); + float *pb = b.GetTensorMutableData(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + pb[i] = i + 10; + } + + Ort::Value ans = Stack(allocator, {&a, &b}, 1); + + const float *pans = ans.GetTensorData(); + + for (int32_t i = 0; i != static_cast(a_shape[0]); ++i) { + for (int32_t k = 0; k != static_cast(a_shape[1] * a_shape[2]); + ++k, ++pa, ++pans) { + EXPECT_EQ(*pa, *pans); + } + + for (int32_t k = 0; k != static_cast(b_shape[1] * b_shape[2]); + ++k, ++pb, ++pans) { + EXPECT_EQ(*pb, *pans); + } + } + + Print3D(&a); + Print3D(&b); + Print4D(&ans); +} + +TEST(Stack, Test3DTensorsDim2) { + Ort::AllocatorWithDefaultOptions allocator; + + std::array a_shape{2, 3, 4}; + std::array b_shape{2, 3, 4}; + + Ort::Value a = Ort::Value::CreateTensor(allocator, a_shape.data(), + a_shape.size()); + + Ort::Value b = Ort::Value::CreateTensor(allocator, b_shape.data(), + b_shape.size()); + + float *pa = a.GetTensorMutableData(); + float *pb = b.GetTensorMutableData(); + for (int32_t i = 0; + i != static_cast(a_shape[0] * a_shape[1] * a_shape[2]); ++i) { + pa[i] = i; + } + for (int32_t i = 0; + i != static_cast(b_shape[0] * b_shape[1] * b_shape[2]); ++i) { + pb[i] = i + 10; + } + + Ort::Value ans = Stack(allocator, {&a, &b}, 2); + + const float *pans = ans.GetTensorData(); + + for (int32_t i = 0; i != static_cast(a_shape[0] * a_shape[1]); ++i) { + for (int32_t k = 0; k != static_cast(a_shape[2]); + ++k, ++pa, ++pans) { + EXPECT_EQ(*pa, *pans); + } + + for (int32_t k = 0; k != static_cast(b_shape[2]); + ++k, ++pb, ++pans) { + EXPECT_EQ(*pb, *pans); + } + } + + Print3D(&a); + Print3D(&b); + Print4D(&ans); +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/stack.cc b/sherpa-onnx/csrc/stack.cc new file mode 100644 index 00000000..c7ae6bee --- /dev/null +++ b/sherpa-onnx/csrc/stack.cc @@ -0,0 +1,101 @@ +// sherpa-onnx/csrc/stack.cc +// +// Copyright (c) 2023 Jingzhao Ou (jingzhao.ou@gmail.com) + +#include "sherpa-onnx/csrc/stack.h" + +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/onnx-utils.h" + +namespace sherpa_onnx { + +static bool Compare(const std::vector &a, + const std::vector &b) { + if (a.size() != b.size()) return false; + + for (int32_t i = 0; i != static_cast(a.size()); ++i) { + if (a[i] != b[i]) return false; + } + + return true; +} + +static void PrintShape(const std::vector &a) { + for (auto i : a) { + fprintf(stderr, "%d ", static_cast(i)); + } + fprintf(stderr, "\n"); +} + +template +Ort::Value Stack(OrtAllocator *allocator, + const std::vector &values, int32_t dim) { + std::vector v0_shape = + values[0]->GetTensorTypeAndShapeInfo().GetShape(); + + for (int32_t i = 1; i != static_cast(values.size()); ++i) { + auto s = 
        values[i]->GetTensorTypeAndShapeInfo().GetShape();
+    bool ret = Compare(v0_shape, s);
+    if (!ret) {
+      fprintf(stderr, "Incorrect shape in Stack!\n");
+
+      fprintf(stderr, "Shape for tensor 0: ");
+      PrintShape(v0_shape);
+
+      fprintf(stderr, "Shape for tensor %d: ", i);
+      PrintShape(s);
+
+      exit(-1);
+    }
+  }
+
+  std::vector<int64_t> ans_shape;
+  ans_shape.reserve(v0_shape.size() + 1);
+  ans_shape.insert(ans_shape.end(), v0_shape.data(), v0_shape.data() + dim);
+  ans_shape.push_back(values.size());
+  ans_shape.insert(ans_shape.end(), v0_shape.data() + dim,
+                   v0_shape.data() + v0_shape.size());
+
+  auto leading_size = static_cast<int32_t>(std::accumulate(
+      v0_shape.begin(), v0_shape.begin() + dim, 1, std::multiplies<int64_t>()));
+
+  auto trailing_size = static_cast<int32_t>(std::accumulate(
+      v0_shape.begin() + dim, v0_shape.end(), 1, std::multiplies<int64_t>()));
+
+  Ort::Value ans = Ort::Value::CreateTensor<T>(allocator, ans_shape.data(),
+                                               ans_shape.size());
+  T *dst = ans.GetTensorMutableData<T>();
+
+  for (int32_t i = 0; i != leading_size; ++i) {
+    for (int32_t n = 0; n != static_cast<int32_t>(values.size()); ++n) {
+      const T *src = values[n]->GetTensorData<T>();
+      src += i * trailing_size;
+
+      std::copy(src, src + trailing_size, dst);
+      dst += trailing_size;
+    }
+  }
+
+  return ans;
+}
+
+template Ort::Value Stack<float>(
+    OrtAllocator *allocator, const std::vector<const Ort::Value *> &values,
+    int32_t dim);
+
+template Ort::Value Stack<int64_t>(
+    OrtAllocator *allocator, const std::vector<const Ort::Value *> &values,
+    int32_t dim);
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/stack.h b/sherpa-onnx/csrc/stack.h
new file mode 100644
index 00000000..55f199aa
--- /dev/null
+++ b/sherpa-onnx/csrc/stack.h
@@ -0,0 +1,29 @@
+// sherpa-onnx/csrc/stack.h
+//
+// Copyright (c)  2023  Jingzhao Ou (jingzhao.ou@gmail.com)
+
+#ifndef SHERPA_ONNX_CSRC_STACK_H_
+#define SHERPA_ONNX_CSRC_STACK_H_
+
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+
+namespace sherpa_onnx {
+
+/** Stack a list of tensors along the given dim.
+ *
+ * @param allocator Allocator to allocate space for the returned tensor
+ * @param values Pointer to a list of tensors. All tensors must have the
+ *               same shape; a new dim of size values.size() is inserted
+ *               at the position given by dim.
+ * @param dim The dim along which to stack the input tensors
+ *
+ * @return Return the stacked tensor
+ */
+template <typename T = float>
+Ort::Value Stack(OrtAllocator *allocator,
+                 const std::vector<const Ort::Value *> &values, int32_t dim);
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_STACK_H_
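Reviewer note: below is a minimal usage sketch of the new Stack() helper, kept outside the patch itself. It only assumes Stack() from sherpa-onnx/csrc/stack.h as declared above and the standard Ort::Value tensor factory; the function name StackExample and the concrete shapes are illustrative, not part of the PR.

// Stack two float tensors of identical shape (2, 3) along dim 0,
// producing a tensor of shape (2, 2, 3).  Stack() requires all inputs
// to have exactly the same shape, unlike Cat(), which only requires
// them to match on the non-concatenated dims.
#include <array>
#include <numeric>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT
#include "sherpa-onnx/csrc/stack.h"

void StackExample() {
  Ort::AllocatorWithDefaultOptions allocator;

  std::array<int64_t, 2> shape{2, 3};
  Ort::Value a =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());
  Ort::Value b =
      Ort::Value::CreateTensor<float>(allocator, shape.data(), shape.size());

  float *pa = a.GetTensorMutableData<float>();
  float *pb = b.GetTensorMutableData<float>();
  std::iota(pa, pa + 6, 0.0f);   // a = [0, 1, 2, 3, 4, 5]
  std::iota(pb, pb + 6, 10.0f);  // b = [10, 11, 12, 13, 14, 15]

  std::vector<const Ort::Value *> values = {&a, &b};
  Ort::Value stacked = sherpa_onnx::Stack<float>(allocator, values, /*dim=*/0);
  // stacked now has shape (2, 2, 3), with a's data followed by b's data.
}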