diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt
index 877c31ed..9bbe3645 100644
--- a/sherpa-onnx/csrc/CMakeLists.txt
+++ b/sherpa-onnx/csrc/CMakeLists.txt
@@ -48,6 +48,7 @@ set(sources
   online-transducer-model.cc
   online-transducer-modified-beam-search-decoder.cc
   online-zipformer-transducer-model.cc
+  online-zipformer2-transducer-model.cc
   onnx-utils.cc
   session.cc
   packed-sequence.cc
diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc
index 5d60a021..d8cb578f 100644
--- a/sherpa-onnx/csrc/online-transducer-model.cc
+++ b/sherpa-onnx/csrc/online-transducer-model.cc
@@ -18,6 +18,7 @@
 #include "sherpa-onnx/csrc/online-conformer-transducer-model.h"
 #include "sherpa-onnx/csrc/online-lstm-transducer-model.h"
 #include "sherpa-onnx/csrc/online-zipformer-transducer-model.h"
+#include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 
 namespace {
@@ -26,6 +27,7 @@ enum class ModelType {
   kConformer,
   kLstm,
   kZipformer,
+  kZipformer2,
   kUnkown,
 };
 
@@ -65,6 +67,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length,
     return ModelType::kLstm;
   } else if (model_type.get() == std::string("zipformer")) {
     return ModelType::kZipformer;
+  } else if (model_type.get() == std::string("zipformer2")) {
+    return ModelType::kZipformer2;
   } else {
     SHERPA_ONNX_LOGE("Unsupported model_type: %s", model_type.get());
     return ModelType::kUnkown;
@@ -88,6 +92,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
       return std::make_unique<OnlineLstmTransducerModel>(config);
     case ModelType::kZipformer:
       return std::make_unique<OnlineZipformerTransducerModel>(config);
+    case ModelType::kZipformer2:
+      return std::make_unique<OnlineZipformer2TransducerModel>(config);
     case ModelType::kUnkown:
       SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
       return nullptr;
@@ -144,6 +150,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
       return std::make_unique<OnlineLstmTransducerModel>(mgr, config);
     case ModelType::kZipformer:
       return std::make_unique<OnlineZipformerTransducerModel>(mgr, config);
+    case ModelType::kZipformer2:
+      return std::make_unique<OnlineZipformer2TransducerModel>(mgr, config);
     case ModelType::kUnkown:
       SHERPA_ONNX_LOGE("Unknown model type in online transducer!");
       return nullptr;
diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc
new file mode 100644
index 00000000..0ffb6850
--- /dev/null
+++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.cc
@@ -0,0 +1,397 @@
+// sherpa-onnx/csrc/online-zipformer2-transducer-model.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h"
+
+#include <assert.h>
+#include <stdio.h>
+
+#include <array>
+#include <memory>
+#include <numeric>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/cat.h"
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/online-transducer-decoder.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+#include "sherpa-onnx/csrc/unbind.h"
+
+namespace sherpa_onnx {
+
+// Reads the encoder, decoder and joiner model files named in `config` and
+// initializes one session for each of them.
+OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel(
+    const OnlineTransducerModelConfig &config)
+    : env_(ORT_LOGGING_LEVEL_WARNING),
+      sess_opts_(GetSessionOptions(config)),
+      allocator_{},
+      config_(config) {
+  {
+    auto buf = ReadFile(config.encoder_filename);
+    InitEncoder(buf.data(), buf.size());
+  }
+
+  {
+    auto buf = ReadFile(config.decoder_filename);
+    InitDecoder(buf.data(), buf.size());
+  }
+
+  {
+    auto buf = ReadFile(config.joiner_filename);
+    InitJoiner(buf.data(), buf.size());
+  }
+}
+
+#if __ANDROID_API__ >= 9
+// Same as above, but the model files are read through the Android asset
+// manager `mgr`.
+OnlineZipformer2TransducerModel::OnlineZipformer2TransducerModel(
+    AAssetManager *mgr, const OnlineTransducerModelConfig &config)
+    : env_(ORT_LOGGING_LEVEL_WARNING),
+      sess_opts_(GetSessionOptions(config)),
+      allocator_{},
+      config_(config) {
+  {
+    auto buf = ReadFile(mgr, config.encoder_filename);
+    InitEncoder(buf.data(), buf.size());
+  }
+
+  {
+    auto buf = ReadFile(mgr, config.decoder_filename);
+    InitDecoder(buf.data(), buf.size());
+  }
+
+  {
+    auto buf = ReadFile(mgr, config.joiner_filename);
+    InitJoiner(buf.data(), buf.size());
+  }
+}
+#endif
+
+// Creates the encoder session from the in-memory model and caches the
+// metadata entries read below (the vector-valued settings, `T` and
+// `decode_chunk_len`).
+void OnlineZipformer2TransducerModel::InitEncoder(void *model_data,
+                                                  size_t model_data_length) {
+  encoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
+                                                 model_data_length, sess_opts_);
+
+  GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+                &encoder_input_names_ptr_);
+
+  GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+                 &encoder_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---encoder---\n";
+    PrintModelMetadata(os, meta_data);
+    SHERPA_ONNX_LOGE("%s", os.str().c_str());
+  }
+
+  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+  SHERPA_ONNX_READ_META_DATA_VEC(encoder_dims_, "encoder_dims");
+  SHERPA_ONNX_READ_META_DATA_VEC(query_head_dims_, "query_head_dims");
+  SHERPA_ONNX_READ_META_DATA_VEC(value_head_dims_, "value_head_dims");
+  SHERPA_ONNX_READ_META_DATA_VEC(num_heads_, "num_heads");
+  SHERPA_ONNX_READ_META_DATA_VEC(num_encoder_layers_, "num_encoder_layers");
+  SHERPA_ONNX_READ_META_DATA_VEC(cnn_module_kernels_, "cnn_module_kernels");
+  SHERPA_ONNX_READ_META_DATA_VEC(left_context_len_, "left_context_len");
+
+  SHERPA_ONNX_READ_META_DATA(T_, "T");
+  SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len");
+
+  if (config_.debug) {
+    auto print = [](const std::vector<int32_t> &v, const char *name) {
+      fprintf(stderr, "%s: ", name);
+      for (auto i : v) {
+        fprintf(stderr, "%d ", i);
+      }
+      fprintf(stderr, "\n");
+    };
+    print(encoder_dims_, "encoder_dims");
+    print(query_head_dims_, "query_head_dims");
+    print(value_head_dims_, "value_head_dims");
+    print(num_heads_, "num_heads");
+    print(num_encoder_layers_, "num_encoder_layers");
+    print(cnn_module_kernels_, "cnn_module_kernels");
+    print(left_context_len_, "left_context_len");
+    SHERPA_ONNX_LOGE("T: %d", T_);
+    SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_);
+  }
+}
+
+// Creates the decoder session from the in-memory model and reads
+// `vocab_size` and `context_size` from its metadata.
+void OnlineZipformer2TransducerModel::InitDecoder(void *model_data,
+                                                  size_t model_data_length) {
+  decoder_sess_ = std::make_unique<Ort::Session>(env_, model_data,
+                                                 model_data_length, sess_opts_);
+
+  GetInputNames(decoder_sess_.get(), &decoder_input_names_,
+                &decoder_input_names_ptr_);
+
+  GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
+                 &decoder_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---decoder---\n";
+    PrintModelMetadata(os, meta_data);
+    SHERPA_ONNX_LOGE("%s", os.str().c_str());
+  }
+
+  Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+  SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size");
+  SHERPA_ONNX_READ_META_DATA(context_size_, "context_size");
+}
+
+// Creates the joiner session from the in-memory model. Its metadata is only
+// printed (when `config_.debug` is set); nothing is cached from it.
+void OnlineZipformer2TransducerModel::InitJoiner(void *model_data,
+                                                 size_t model_data_length) {
+  joiner_sess_ = std::make_unique<Ort::Session>(env_, model_data,
+                                                model_data_length, sess_opts_);
+
+  GetInputNames(joiner_sess_.get(), &joiner_input_names_,
+                &joiner_input_names_ptr_);
+
+  GetOutputNames(joiner_sess_.get(), &joiner_output_names_,
+                 &joiner_output_names_ptr_);
+
+  // get meta data
+  Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata();
+  if (config_.debug) {
+    std::ostringstream os;
+    os << "---joiner---\n";
+    PrintModelMetadata(os, meta_data);
+    SHERPA_ONNX_LOGE("%s", os.str().c_str());
+  }
+}
+
+// Merges the encoder states of several streams (states[n] belongs to stream
+// n) into one batched state list.
+//
+// Within every group of six consecutive states, the first four are
+// concatenated along dim 1 and the last two along dim 0. The two trailing
+// states are concatenated along dim 0.
+std::vector<Ort::Value> OnlineZipformer2TransducerModel::StackStates(
+    const std::vector<std::vector<Ort::Value>> &states) const {
+  // states[0] is read below, so an empty batch is a caller error.
+  assert(!states.empty());
+
+  int32_t batch_size = static_cast<int32_t>(states.size());
+  int32_t num_states = static_cast<int32_t>(states[0].size());
+
+  std::vector<const Ort::Value *> buf(batch_size);
+
+  std::vector<Ort::Value> ans;
+  ans.reserve(num_states);
+
+  // Points buf[n] at the `index`-th state of stream n.
+  auto gather = [&](int32_t index) {
+    for (int32_t n = 0; n != batch_size; ++n) {
+      buf[n] = &states[n][index];
+    }
+  };
+
+  for (int32_t i = 0; i != (num_states - 2) / 6; ++i) {
+    for (int32_t k = 0; k != 6; ++k) {
+      gather(6 * i + k);
+      ans.push_back(Cat(allocator_, buf, k < 4 ? 1 : 0));
+    }
+  }
+
+  gather(num_states - 2);
+  ans.push_back(Cat(allocator_, buf, 0));
+
+  // NOTE(review): the last state is assumed to be an int64 tensor (see
+  // GetEncoderInitStates()) - confirm against the exported model.
+  gather(num_states - 1);
+  ans.push_back(Cat<int64_t>(allocator_, buf, 0));
+
+  return ans;
+}
+
+// Splits batched encoder states into one state list per stream; the
+// per-state split dims mirror the concatenation dims used by StackStates().
+//
+// `states` must hold 6 * (sum of num_encoder_layers_) + 2 tensors.
+std::vector<std::vector<Ort::Value>>
+OnlineZipformer2TransducerModel::UnStackStates(
+    const std::vector<Ort::Value> &states) const {
+  int32_t m = std::accumulate(num_encoder_layers_.begin(),
+                              num_encoder_layers_.end(), 0);
+  assert(static_cast<int32_t>(states.size()) == m * 6 + 2);
+
+  // The batch size is read from dim 1 of the first state.
+  int32_t batch_size = static_cast<int32_t>(
+      states[0].GetTensorTypeAndShapeInfo().GetShape()[1]);
+
+  std::vector<std::vector<Ort::Value>> ans(batch_size);
+  for (auto &s : ans) {
+    s.reserve(states.size());
+  }
+
+  // Moves slice n of `v` to the end of stream n's state list.
+  auto scatter = [&](std::vector<Ort::Value> v) {
+    assert(static_cast<int32_t>(v.size()) == batch_size);
+    for (int32_t n = 0; n != batch_size; ++n) {
+      ans[n].push_back(std::move(v[n]));
+    }
+  };
+
+  for (int32_t i = 0; i != m * 6; ++i) {
+    scatter(Unbind(allocator_, &states[i], i % 6 < 4 ? 1 : 0));
+  }
+
+  scatter(Unbind(allocator_, &states[m * 6], 0));
+
+  // NOTE(review): the last state is assumed to be an int64 tensor (see
+  // GetEncoderInitStates()) - confirm against the exported model.
+  scatter(Unbind<int64_t>(allocator_, &states[m * 6 + 1], 0));
+
+  return ans;
+}
+
+// Returns zero-filled initial encoder states for a single stream: six
+// tensors per encoder layer followed by two trailing states.
+std::vector<Ort::Value>
+OnlineZipformer2TransducerModel::GetEncoderInitStates() {
+  int32_t n = static_cast<int32_t>(encoder_dims_.size());
+  int32_t m = std::accumulate(num_encoder_layers_.begin(),
+                              num_encoder_layers_.end(), 0);
+
+  // Every metadata vector below is indexed with i in [0, n).
+  assert(static_cast<int32_t>(num_encoder_layers_.size()) >= n);
+  assert(static_cast<int32_t>(query_head_dims_.size()) >= n);
+  assert(static_cast<int32_t>(value_head_dims_.size()) >= n);
+  assert(static_cast<int32_t>(num_heads_.size()) >= n);
+  assert(static_cast<int32_t>(cnn_module_kernels_.size()) >= n);
+  assert(static_cast<int32_t>(left_context_len_.size()) >= n);
+
+  std::vector<Ort::Value> ans;
+  ans.reserve(m * 6 + 2);
+
+  // Appends a zero-filled float tensor with the given shape to `ans`.
+  auto push_zeros = [&](const int64_t *shape, size_t rank) {
+    auto v = Ort::Value::CreateTensor<float>(allocator_, shape, rank);
+    Fill(&v, 0);
+    ans.push_back(std::move(v));
+  };
+
+  for (int32_t i = 0; i != n; ++i) {
+    int32_t num_layers = num_encoder_layers_[i];
+    int32_t key_dim = query_head_dims_[i] * num_heads_[i];
+    int32_t value_dim = value_head_dims_[i] * num_heads_[i];
+    int32_t nonlin_attn_head_dim = 3 * encoder_dims_[i] / 4;
+
+    // The shapes depend only on the index i, not on the layer index j.
+    std::array<int64_t, 3> key_shape{left_context_len_[i], 1, key_dim};
+    std::array<int64_t, 4> nonlin_attn_shape{1, 1, left_context_len_[i],
+                                             nonlin_attn_head_dim};
+    std::array<int64_t, 3> value_shape{left_context_len_[i], 1, value_dim};
+    std::array<int64_t, 3> conv_shape{1, encoder_dims_[i],
+                                      cnn_module_kernels_[i] / 2};
+
+    for (int32_t j = 0; j != num_layers; ++j) {
+      push_zeros(key_shape.data(), key_shape.size());
+      push_zeros(nonlin_attn_shape.data(), nonlin_attn_shape.size());
+      push_zeros(value_shape.data(), value_shape.size());
+      push_zeros(value_shape.data(), value_shape.size());
+      push_zeros(conv_shape.data(), conv_shape.size());
+      push_zeros(conv_shape.data(), conv_shape.size());
+    }
+  }
+
+  {
+    // NOTE(review): this shape is a fixed constant rather than model
+    // metadata; presumably it matches the exported model - confirm.
+    std::array<int64_t, 4> s{1, 128, 3, 19};
+    push_zeros(s.data(), s.size());
+  }
+
+  {
+    // NOTE(review): assumed to be an int64 tensor - confirm against the
+    // exported model.
+    std::array<int64_t, 1> s{1};
+    auto v = Ort::Value::CreateTensor<int64_t>(allocator_, s.data(), s.size());
+    Fill<int64_t>(&v, 0);
+    ans.push_back(std::move(v));
+  }
+  return ans;
+}
+
+// Runs the encoder on `features` using `states`.
+//
+// Returns the first encoder output paired with all remaining outputs, which
+// are collected as the next states. The third argument is ignored.
+std::pair<Ort::Value, std::vector<Ort::Value>>
+OnlineZipformer2TransducerModel::RunEncoder(Ort::Value features,
+                                            std::vector<Ort::Value> states,
+                                            Ort::Value /* processed_frames */) {
+  std::vector<Ort::Value> encoder_inputs;
+  encoder_inputs.reserve(1 + states.size());
+
+  encoder_inputs.push_back(std::move(features));
+  for (auto &v : states) {
+    encoder_inputs.push_back(std::move(v));
+  }
+
+  auto encoder_out = encoder_sess_->Run(
+      {}, encoder_input_names_ptr_.data(), encoder_inputs.data(),
+      encoder_inputs.size(), encoder_output_names_ptr_.data(),
+      encoder_output_names_ptr_.size());
+  assert(!encoder_out.empty());
+
+  std::vector<Ort::Value> next_states;
+  next_states.reserve(states.size());
+
+  // `<` rather than `!=` so that the loop terminates for any output count.
+  for (int32_t i = 1; i < static_cast<int32_t>(encoder_out.size()); ++i) {
+    next_states.push_back(std::move(encoder_out[i]));
+  }
+  return {std::move(encoder_out[0]), std::move(next_states)};
+}
+
+// Runs the decoder on `decoder_input` and returns its first output.
+Ort::Value OnlineZipformer2TransducerModel::RunDecoder(
+    Ort::Value decoder_input) {
+  auto decoder_out = decoder_sess_->Run(
+      {}, decoder_input_names_ptr_.data(), &decoder_input, 1,
+      decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size());
+  return std::move(decoder_out[0]);
+}
+
+// Runs the joiner on the encoder and decoder outputs and returns its first
+// output.
+Ort::Value OnlineZipformer2TransducerModel::RunJoiner(Ort::Value encoder_out,
+                                                      Ort::Value decoder_out) {
+  std::array<Ort::Value, 2> joiner_input = {std::move(encoder_out),
+                                            std::move(decoder_out)};
+  auto logit =
+      joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(),
+                        joiner_input.size(), joiner_output_names_ptr_.data(),
+                        joiner_output_names_ptr_.size());
+
+  return std::move(logit[0]);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/online-zipformer2-transducer-model.h b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h
new file mode 100644
index 00000000..57b63e02
--- /dev/null
+++ b/sherpa-onnx/csrc/online-zipformer2-transducer-model.h
@@ -0,0 +1,120 @@
+// sherpa-onnx/csrc/online-zipformer2-transducer-model.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_
+#define SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#if __ANDROID_API__ >= 9
+#include "android/asset_manager.h"
+#include "android/asset_manager_jni.h"
+#endif
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/online-transducer-model-config.h"
+#include "sherpa-onnx/csrc/online-transducer-model.h"
+
+namespace sherpa_onnx {
+
+// OnlineTransducerModel implementation selected for models whose metadata
+// declares model_type "zipformer2".
+class OnlineZipformer2TransducerModel : public OnlineTransducerModel {
+ public:
+  explicit OnlineZipformer2TransducerModel(
+      const OnlineTransducerModelConfig &config);
+
+#if __ANDROID_API__ >= 9
+  OnlineZipformer2TransducerModel(AAssetManager *mgr,
+                                  const OnlineTransducerModelConfig &config);
+#endif
+
+  // Combines per-stream encoder states into one batched state list.
+  std::vector<Ort::Value> StackStates(
+      const std::vector<std::vector<Ort::Value>> &states) const override;
+
+  // Splits a batched state list into one state list per stream.
+  std::vector<std::vector<Ort::Value>> UnStackStates(
+      const std::vector<Ort::Value> &states) const override;
+
+  // Returns zero-filled initial encoder states for a single stream.
+  std::vector<Ort::Value> GetEncoderInitStates() override;
+
+  // Returns the encoder output and the next states. `processed_frames` is
+  // ignored by this model.
+  std::pair<Ort::Value, std::vector<Ort::Value>> RunEncoder(
+      Ort::Value features, std::vector<Ort::Value> states,
+      Ort::Value processed_frames) override;
+
+  Ort::Value RunDecoder(Ort::Value decoder_input) override;
+
+  Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override;
+
+  int32_t ContextSize() const override { return context_size_; }
+
+  // Value of the encoder metadata entry "T".
+  int32_t ChunkSize() const override { return T_; }
+
+  // Value of the encoder metadata entry "decode_chunk_len".
+  int32_t ChunkShift() const override { return decode_chunk_len_; }
+
+  int32_t VocabSize() const override { return vocab_size_; }
+  OrtAllocator *Allocator() override { return allocator_; }
+
+ private:
+  void InitEncoder(void *model_data, size_t model_data_length);
+  void InitDecoder(void *model_data, size_t model_data_length);
+  void InitJoiner(void *model_data, size_t model_data_length);
+
+ private:
+  // NOTE: the constructors initialize members in this declaration order.
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> encoder_sess_;
+  std::unique_ptr<Ort::Session> decoder_sess_;
+  std::unique_ptr<Ort::Session> joiner_sess_;
+
+  std::vector<std::string> encoder_input_names_;
+  std::vector<const char *> encoder_input_names_ptr_;
+
+  std::vector<std::string> encoder_output_names_;
+  std::vector<const char *> encoder_output_names_ptr_;
+
+  std::vector<std::string> decoder_input_names_;
+  std::vector<const char *> decoder_input_names_ptr_;
+
+  std::vector<std::string> decoder_output_names_;
+  std::vector<const char *> decoder_output_names_ptr_;
+
+  std::vector<std::string> joiner_input_names_;
+  std::vector<const char *> joiner_input_names_ptr_;
+
+  std::vector<std::string> joiner_output_names_;
+  std::vector<const char *> joiner_output_names_ptr_;
+
+  OnlineTransducerModelConfig config_;
+
+  // Vector-valued encoder metadata; GetEncoderInitStates() indexes all of
+  // them with the same position.
+  std::vector<int32_t> encoder_dims_;
+  std::vector<int32_t> query_head_dims_;
+  std::vector<int32_t> value_head_dims_;
+  std::vector<int32_t> num_heads_;
+  std::vector<int32_t> num_encoder_layers_;
+  std::vector<int32_t> cnn_module_kernels_;
+  std::vector<int32_t> left_context_len_;
+
+  int32_t T_ = 0;
+  int32_t decode_chunk_len_ = 0;
+
+  int32_t context_size_ = 0;
+  int32_t vocab_size_ = 0;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_ONLINE_ZIPFORMER2_TRANSDUCER_MODEL_H_