diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 80804a6f..2bd38404 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -68,6 +68,7 @@ set(sources online-ctc-fst-decoder.cc online-ctc-greedy-search-decoder.cc online-ctc-model.cc + online-ebranchformer-transducer-model.cc online-lm-config.cc online-lm.cc online-lstm-transducer-model.cc diff --git a/sherpa-onnx/csrc/features.cc b/sherpa-onnx/csrc/features.cc index 5b50d5f2..16632513 100644 --- a/sherpa-onnx/csrc/features.cc +++ b/sherpa-onnx/csrc/features.cc @@ -48,7 +48,9 @@ std::string FeatureExtractorConfig::ToString() const { os << "feature_dim=" << feature_dim << ", "; os << "low_freq=" << low_freq << ", "; os << "high_freq=" << high_freq << ", "; - os << "dither=" << dither << ")"; + os << "dither=" << dither << ", "; + os << "normalize_samples=" << (normalize_samples ? "True" : "False") << ", "; + os << "snip_edges=" << (snip_edges ? "True" : "False") << ")"; return os.str(); } diff --git a/sherpa-onnx/csrc/online-ebranchformer-transducer-model.cc b/sherpa-onnx/csrc/online-ebranchformer-transducer-model.cc new file mode 100644 index 00000000..84d81e8e --- /dev/null +++ b/sherpa-onnx/csrc/online-ebranchformer-transducer-model.cc @@ -0,0 +1,438 @@ +// sherpa-onnx/csrc/online-ebranchformer-transducer-model.cc +// +// Copyright (c) 2023 Xiaomi Corporation +// 2025 Brno University of Technology (author: Karel Vesely) + +#include "sherpa-onnx/csrc/online-ebranchformer-transducer-model.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/cat.h" +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include
"sherpa-onnx/csrc/online-transducer-decoder.h" +#include "sherpa-onnx/csrc/onnx-utils.h" +#include "sherpa-onnx/csrc/session.h" +#include "sherpa-onnx/csrc/text-utils.h" +#include "sherpa-onnx/csrc/unbind.h" + +namespace sherpa_onnx { + +// Builds the encoder, decoder and joiner sessions from the model files named +// in config.transducer. The initializer list follows the member declaration +// order of the class (allocator_ is declared before config_). +OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel( + const OnlineModelConfig &config) + : env_(ORT_LOGGING_LEVEL_ERROR), + encoder_sess_opts_(GetSessionOptions(config)), + decoder_sess_opts_(GetSessionOptions(config, "decoder")), + joiner_sess_opts_(GetSessionOptions(config, "joiner")), + allocator_{}, + config_(config) { + { + auto buf = ReadFile(config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } +} + +// Same as the constructor above, except that the model files are read through +// `mgr` (explicitly instantiated at the end of this file for AAssetManager and +// NativeResourceManager). +template +OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel( + Manager *mgr, const OnlineModelConfig &config) + : env_(ORT_LOGGING_LEVEL_ERROR), + encoder_sess_opts_(GetSessionOptions(config)), + // Request the decoder/joiner specific options, exactly as the file-based + // constructor does, so both construction paths configure the sessions + // identically. + decoder_sess_opts_(GetSessionOptions(config, "decoder")), + joiner_sess_opts_(GetSessionOptions(config, "joiner")), + allocator_{}, + config_(config) { + { + auto buf = ReadFile(mgr, config.transducer.encoder); + InitEncoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.decoder); + InitDecoder(buf.data(), buf.size()); + } + + { + auto buf = ReadFile(mgr, config.transducer.joiner); + InitJoiner(buf.data(), buf.size()); + } +} + + +// Creates the encoder session from the in-memory model, records its +// input/output names and reads the model dimensions from its metadata. +void OnlineEbranchformerTransducerModel::InitEncoder(void *model_data, + size_t model_data_length) { + encoder_sess_ = std::make_unique( + env_, model_data, model_data_length, encoder_sess_opts_); + + GetInputNames(encoder_sess_.get(), &encoder_input_names_, + &encoder_input_names_ptr_); + + GetOutputNames(encoder_sess_.get(), &encoder_output_names_, + &encoder_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data =
encoder_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---encoder---\n"; + PrintModelMetadata(os, meta_data); +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + + SHERPA_ONNX_READ_META_DATA(decode_chunk_len_, "decode_chunk_len"); + SHERPA_ONNX_READ_META_DATA(T_, "T"); + + SHERPA_ONNX_READ_META_DATA(num_hidden_layers_, "num_hidden_layers"); + SHERPA_ONNX_READ_META_DATA(hidden_size_, "hidden_size"); + SHERPA_ONNX_READ_META_DATA(intermediate_size_, "intermediate_size"); + SHERPA_ONNX_READ_META_DATA(csgu_kernel_size_, "csgu_kernel_size"); + SHERPA_ONNX_READ_META_DATA(merge_conv_kernel_, "merge_conv_kernel"); + SHERPA_ONNX_READ_META_DATA(left_context_len_, "left_context_len"); + SHERPA_ONNX_READ_META_DATA(num_heads_, "num_heads"); + SHERPA_ONNX_READ_META_DATA(head_dim_, "head_dim"); + + if (config_.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("T: %{public}d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_); + + SHERPA_ONNX_LOGE("num_hidden_layers_: %{public}d", num_hidden_layers_); + SHERPA_ONNX_LOGE("hidden_size_: %{public}d", hidden_size_); + SHERPA_ONNX_LOGE("intermediate_size_: %{public}d", intermediate_size_); + SHERPA_ONNX_LOGE("csgu_kernel_size_: %{public}d", csgu_kernel_size_); + SHERPA_ONNX_LOGE("merge_conv_kernel_: %{public}d", merge_conv_kernel_); + SHERPA_ONNX_LOGE("left_context_len_: %{public}d", left_context_len_); + SHERPA_ONNX_LOGE("num_heads_: %{public}d", num_heads_); + SHERPA_ONNX_LOGE("head_dim_: %{public}d", head_dim_); +#else + SHERPA_ONNX_LOGE("T: %d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); + + SHERPA_ONNX_LOGE("num_hidden_layers_: %d", num_hidden_layers_); + SHERPA_ONNX_LOGE("hidden_size_: %d", hidden_size_); + SHERPA_ONNX_LOGE("intermediate_size_: %d", intermediate_size_); +
SHERPA_ONNX_LOGE("csgu_kernel_size_: %d", csgu_kernel_size_); + SHERPA_ONNX_LOGE("merge_conv_kernel_: %d", merge_conv_kernel_); + SHERPA_ONNX_LOGE("left_context_len_: %d", left_context_len_); + SHERPA_ONNX_LOGE("num_heads_: %d", num_heads_); + SHERPA_ONNX_LOGE("head_dim_: %d", head_dim_); +#endif + } +} + + +// Creates the decoder session from the in-memory model, records its +// input/output names and reads vocab_size / context_size from its metadata. +void OnlineEbranchformerTransducerModel::InitDecoder(void *model_data, + size_t model_data_length) { + decoder_sess_ = std::make_unique( + env_, model_data, model_data_length, decoder_sess_opts_); + + GetInputNames(decoder_sess_.get(), &decoder_input_names_, + &decoder_input_names_ptr_); + + GetOutputNames(decoder_sess_.get(), &decoder_output_names_, + &decoder_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = decoder_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---decoder---\n"; + PrintModelMetadata(os, meta_data); + // Mirror InitEncoder(): OHOS builds log with the %{public}s format. +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } + + Ort::AllocatorWithDefaultOptions allocator; // used in the macro below + SHERPA_ONNX_READ_META_DATA(vocab_size_, "vocab_size"); + SHERPA_ONNX_READ_META_DATA(context_size_, "context_size"); +} + +// Creates the joiner session from the in-memory model and records its +// input/output names. No metadata values are read from the joiner. +void OnlineEbranchformerTransducerModel::InitJoiner(void *model_data, + size_t model_data_length) { + joiner_sess_ = std::make_unique( + env_, model_data, model_data_length, joiner_sess_opts_); + + GetInputNames(joiner_sess_.get(), &joiner_input_names_, + &joiner_input_names_ptr_); + + GetOutputNames(joiner_sess_.get(), &joiner_output_names_, + &joiner_output_names_ptr_); + + // get meta data + Ort::ModelMetadata meta_data = joiner_sess_->GetModelMetadata(); + if (config_.debug) { + std::ostringstream os; + os << "---joiner---\n"; + PrintModelMetadata(os, meta_data); + // Mirror InitEncoder(): OHOS builds log with the %{public}s format. +#if __OHOS__ + SHERPA_ONNX_LOGE("%{public}s", os.str().c_str()); +#else + SHERPA_ONNX_LOGE("%s", os.str().c_str()); +#endif + } +} + + +// Concatenates the per-entry state lists in `states` along axis 0 into one +// batched state list. +std::vector OnlineEbranchformerTransducerModel::StackStates( + const std::vector> &states) const { + int32_t batch_size = static_cast(states.size()); + + std::vector buf(batch_size); + + auto allocator = +
const_cast(this)->allocator_; + + std::vector ans; + int32_t num_states = static_cast(states[0].size()); + ans.reserve(num_states); + + // Each layer owns 4 consecutive state tensors; processed_lens is the last + // entry of every per-entry state list. + for (int32_t i = 0; i != num_hidden_layers_; ++i) { + { // cached_key + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][4 * i]; + } + auto v = Cat(allocator, buf, /* axis */ 0); + ans.push_back(std::move(v)); + } + { // cached_value + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][4 * i + 1]; + } + auto v = Cat(allocator, buf, 0); + ans.push_back(std::move(v)); + } + { // cached_conv + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][4 * i + 2]; + } + auto v = Cat(allocator, buf, 0); + ans.push_back(std::move(v)); + } + { // cached_conv_fusion + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][4 * i + 3]; + } + auto v = Cat(allocator, buf, 0); + ans.push_back(std::move(v)); + } + } + + { // processed_lens + for (int32_t n = 0; n != batch_size; ++n) { + buf[n] = &states[n][num_states - 1]; + } + auto v = Cat(allocator, buf, 0); + ans.push_back(std::move(v)); + } + + return ans; +} + + +// Inverse of StackStates(): splits every batched state tensor along axis 0 +// and regroups the pieces into one state list per batch entry. +std::vector> +OnlineEbranchformerTransducerModel::UnStackStates( + const std::vector &states) const { + + assert(static_cast(states.size()) == num_hidden_layers_ * 4 + 1); + + int32_t batch_size = states[0].GetTensorTypeAndShapeInfo().GetShape()[0]; + + auto allocator = + const_cast(this)->allocator_; + + std::vector> ans; + ans.resize(batch_size); + + for (int32_t i = 0; i != num_hidden_layers_; ++i) { + { // cached_key + auto v = Unbind(allocator, &states[i * 4], /* axis */ 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { // cached_value + auto v = Unbind(allocator, &states[i * 4 + 1], 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { // cached_conv + auto v = Unbind(allocator,
&states[i * 4 + 2], 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + { // cached_conv_fusion + auto v = Unbind(allocator, &states[i * 4 + 3], 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + } + + { // processed_lens + auto v = Unbind(allocator, &states.back(), 0); + assert(static_cast(v.size()) == batch_size); + + for (int32_t n = 0; n != batch_size; ++n) { + ans[n].push_back(std::move(v[n])); + } + } + + return ans; +} + + +// Returns zero-filled initial encoder states with a leading dimension of 1: +// four cache tensors per hidden layer followed by processed_lens. +std::vector +OnlineEbranchformerTransducerModel::GetEncoderInitStates() { + std::vector ans; + + ans.reserve(num_hidden_layers_ * 4 + 1); + + int32_t left_context_conv = csgu_kernel_size_ - 1; + int32_t channels_conv = intermediate_size_ / 2; + + int32_t left_context_conv_fusion = merge_conv_kernel_ - 1; + int32_t channels_conv_fusion = 2 * hidden_size_; + + for (int32_t i = 0; i != num_hidden_layers_; ++i) { + { // cached_key_{i} + std::array s{1, num_heads_, left_context_len_, head_dim_}; + auto v = + Ort::Value::CreateTensor(allocator_, s.data(), s.size()); + Fill(&v, 0); + ans.push_back(std::move(v)); + } + + { // cached_value_{i} + std::array s{1, num_heads_, left_context_len_, head_dim_}; + auto v = + Ort::Value::CreateTensor(allocator_, s.data(), s.size()); + Fill(&v, 0); + ans.push_back(std::move(v)); + } + + { // cached_conv_{i} + std::array s{1, channels_conv, left_context_conv}; + auto v = + Ort::Value::CreateTensor(allocator_, s.data(), s.size()); + Fill(&v, 0); + ans.push_back(std::move(v)); + } + + { // cached_conv_fusion_{i} + std::array s{1, channels_conv_fusion, left_context_conv_fusion}; + auto v = + Ort::Value::CreateTensor(allocator_, s.data(), s.size()); + Fill(&v, 0); + ans.push_back(std::move(v)); + } + } // num_hidden_layers_ + + { // processed_lens + std::array s{1}; + auto v = Ort::Value::CreateTensor(allocator_, s.data(),
s.size()); + Fill(&v, 0); + ans.push_back(std::move(v)); + } + + return ans; +} + + +// Runs the encoder session on `features` followed by `states`. Output 0 is +// returned as the encoder output and the remaining outputs as the next +// states. The third argument is not used by this model. +std::pair> +OnlineEbranchformerTransducerModel::RunEncoder(Ort::Value features, + std::vector states, + Ort::Value /* processed_frames */) { + std::vector encoder_inputs; + encoder_inputs.reserve(1 + states.size()); + + encoder_inputs.push_back(std::move(features)); + for (auto &v : states) { + encoder_inputs.push_back(std::move(v)); + } + + auto encoder_out = encoder_sess_->Run( + {}, encoder_input_names_ptr_.data(), encoder_inputs.data(), + encoder_inputs.size(), encoder_output_names_ptr_.data(), + encoder_output_names_ptr_.size()); + + std::vector next_states; + next_states.reserve(states.size()); + + for (int32_t i = 1; i != static_cast(encoder_out.size()); ++i) { + next_states.push_back(std::move(encoder_out[i])); + } + return {std::move(encoder_out[0]), std::move(next_states)}; +} + + +// Runs the decoder session on a single input tensor and returns its first +// output. +Ort::Value OnlineEbranchformerTransducerModel::RunDecoder( + Ort::Value decoder_input) { + auto decoder_out = decoder_sess_->Run( + {}, decoder_input_names_ptr_.data(), &decoder_input, 1, + decoder_output_names_ptr_.data(), decoder_output_names_ptr_.size()); + return std::move(decoder_out[0]); +} + + +// Runs the joiner session on (encoder_out, decoder_out) and returns its first +// output. +Ort::Value OnlineEbranchformerTransducerModel::RunJoiner(Ort::Value encoder_out, + Ort::Value decoder_out) { + std::array joiner_input = {std::move(encoder_out), + std::move(decoder_out)}; + auto logit = + joiner_sess_->Run({}, joiner_input_names_ptr_.data(), joiner_input.data(), + joiner_input.size(), joiner_output_names_ptr_.data(), + joiner_output_names_ptr_.size()); + + return std::move(logit[0]); +} + + +// Explicit instantiations of the Manager-based constructor for the platform +// resource managers. +#if __ANDROID_API__ >= 9 +template OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineEbranchformerTransducerModel::OnlineEbranchformerTransducerModel( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_onnx diff --git
a/sherpa-onnx/csrc/online-ebranchformer-transducer-model.h b/sherpa-onnx/csrc/online-ebranchformer-transducer-model.h new file mode 100644 index 00000000..4329c9f1 --- /dev/null +++ b/sherpa-onnx/csrc/online-ebranchformer-transducer-model.h @@ -0,0 +1,112 @@ +// sherpa-onnx/csrc/online-ebranchformer-transducer-model.h +// +// Copyright (c) 2023 Xiaomi Corporation +// 2025 Brno University of Technology (author: Karel Vesely) +#ifndef SHERPA_ONNX_CSRC_ONLINE_EBRANCHFORMER_TRANSDUCER_MODEL_H_ +#define SHERPA_ONNX_CSRC_ONLINE_EBRANCHFORMER_TRANSDUCER_MODEL_H_ + +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-model-config.h" +#include "sherpa-onnx/csrc/online-transducer-model.h" + +namespace sherpa_onnx { + +// OnlineTransducerModel implementation for Ebranchformer transducer models. +// It owns one ONNX Runtime session each for the encoder, decoder and joiner. +class OnlineEbranchformerTransducerModel : public OnlineTransducerModel { + public: + // Loads the three models from the paths in config.transducer. + explicit OnlineEbranchformerTransducerModel(const OnlineModelConfig &config); + + // Same as above, but reads the model files through `mgr`. + template + OnlineEbranchformerTransducerModel(Manager *mgr, + const OnlineModelConfig &config); + + // Concatenates per-entry state lists along axis 0 into one batched list. + std::vector StackStates( + const std::vector> &states) const override; + + // Inverse of StackStates(). + std::vector> UnStackStates( + const std::vector &states) const override; + + // Zero-filled initial states: 4 tensors per hidden layer plus processed_lens. + std::vector GetEncoderInitStates() override; + + void SetFeatureDim(int32_t feature_dim) override { + feature_dim_ = feature_dim; + } + + // Returns the encoder output and the next states. The implementation + // ignores processed_frames. + std::pair> RunEncoder( + Ort::Value features, std::vector states, + Ort::Value processed_frames) override; + + Ort::Value RunDecoder(Ort::Value decoder_input) override; + + Ort::Value RunJoiner(Ort::Value encoder_out, Ort::Value decoder_out) override; + + // Read from the decoder metadata key "context_size". + int32_t ContextSize() const override { return context_size_; } + + // Read from the encoder metadata key "T". + int32_t ChunkSize() const override { return T_; } + + // Read from the encoder metadata key "decode_chunk_len". + int32_t ChunkShift() const override { return decode_chunk_len_; } + + // Read from the decoder metadata key "vocab_size". + int32_t VocabSize() const override { return vocab_size_; } + OrtAllocator *Allocator() override { return allocator_; } + + private: + // Each Init* creates the corresponding session from an in-memory model + // buffer and caches its input/output names. + void InitEncoder(void *model_data, size_t model_data_length); + void
InitDecoder(void *model_data, size_t model_data_length); + void InitJoiner(void *model_data, size_t model_data_length); + + private: + // NOTE: members are initialized in the order they are declared here, + // regardless of the order used in a constructor's initializer list. + Ort::Env env_; + Ort::SessionOptions encoder_sess_opts_; + Ort::SessionOptions decoder_sess_opts_; + Ort::SessionOptions joiner_sess_opts_; + + Ort::AllocatorWithDefaultOptions allocator_; + + std::unique_ptr encoder_sess_; + std::unique_ptr decoder_sess_; + std::unique_ptr joiner_sess_; + + // For each session: the input/output names and a parallel vector of + // pointers that is passed to Run(). + std::vector encoder_input_names_; + std::vector encoder_input_names_ptr_; + + std::vector encoder_output_names_; + std::vector encoder_output_names_ptr_; + + std::vector decoder_input_names_; + std::vector decoder_input_names_ptr_; + + std::vector decoder_output_names_; + std::vector decoder_output_names_ptr_; + + std::vector joiner_input_names_; + std::vector joiner_input_names_ptr_; + + std::vector joiner_output_names_; + std::vector joiner_output_names_ptr_; + + OnlineModelConfig config_; + + // Filled from the encoder metadata in InitEncoder(). + int32_t decode_chunk_len_ = 0; + int32_t T_ = 0; + + int32_t num_hidden_layers_ = 0; + int32_t hidden_size_ = 0; + int32_t intermediate_size_ = 0; + int32_t csgu_kernel_size_ = 0; + int32_t merge_conv_kernel_ = 0; + int32_t left_context_len_ = 0; + int32_t num_heads_ = 0; + int32_t head_dim_ = 0; + + // Filled from the decoder metadata in InitDecoder(). + int32_t context_size_ = 0; + int32_t vocab_size_ = 0; + // Overwritten by SetFeatureDim(). + int32_t feature_dim_ = 80; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_ONLINE_EBRANCHFORMER_TRANSDUCER_MODEL_H_ diff --git a/sherpa-onnx/csrc/online-transducer-model.cc b/sherpa-onnx/csrc/online-transducer-model.cc index 66838225..286fd9cd 100644 --- a/sherpa-onnx/csrc/online-transducer-model.cc +++ b/sherpa-onnx/csrc/online-transducer-model.cc @@ -21,6 +21,7 @@ #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/macros.h" #include "sherpa-onnx/csrc/online-conformer-transducer-model.h" +#include "sherpa-onnx/csrc/online-ebranchformer-transducer-model.h" #include "sherpa-onnx/csrc/online-lstm-transducer-model.h" #include
"sherpa-onnx/csrc/online-zipformer-transducer-model.h" #include "sherpa-onnx/csrc/online-zipformer2-transducer-model.h" @@ -30,6 +31,7 @@ namespace { enum class ModelType : std::uint8_t { kConformer, + kEbranchformer, kLstm, kZipformer, kZipformer2, @@ -74,6 +76,8 @@ static ModelType GetModelType(char *model_data, size_t model_data_length, if (model_type == "conformer") { return ModelType::kConformer; + } else if (model_type == "ebranchformer") { + return ModelType::kEbranchformer; } else if (model_type == "lstm") { return ModelType::kLstm; } else if (model_type == "zipformer") { @@ -92,6 +96,8 @@ std::unique_ptr OnlineTransducerModel::Create( const auto &model_type = config.model_type; if (model_type == "conformer") { return std::make_unique(config); + } else if (model_type == "ebranchformer") { + return std::make_unique(config); } else if (model_type == "lstm") { return std::make_unique(config); } else if (model_type == "zipformer") { @@ -115,6 +121,8 @@ std::unique_ptr OnlineTransducerModel::Create( switch (model_type) { case ModelType::kConformer: return std::make_unique(config); + case ModelType::kEbranchformer: + return std::make_unique(config); case ModelType::kLstm: return std::make_unique(config); case ModelType::kZipformer: @@ -171,6 +179,8 @@ std::unique_ptr OnlineTransducerModel::Create( const auto &model_type = config.model_type; if (model_type == "conformer") { return std::make_unique(mgr, config); + } else if (model_type == "ebranchformer") { + return std::make_unique(mgr, config); } else if (model_type == "lstm") { return std::make_unique(mgr, config); } else if (model_type == "zipformer") { @@ -190,6 +200,8 @@ std::unique_ptr OnlineTransducerModel::Create( switch (model_type) { case ModelType::kConformer: return std::make_unique(mgr, config); + case ModelType::kEbranchformer: + return std::make_unique(mgr, config); case ModelType::kLstm: return std::make_unique(mgr, config); case ModelType::kZipformer: diff --git 
a/sherpa-onnx/python/csrc/features.cc b/sherpa-onnx/python/csrc/features.cc index 63c0143c..78d8169f 100644 --- a/sherpa-onnx/python/csrc/features.cc +++ b/sherpa-onnx/python/csrc/features.cc @@ -11,15 +11,21 @@ namespace sherpa_onnx { static void PybindFeatureExtractorConfig(py::module *m) { using PyClass = FeatureExtractorConfig; py::class_(*m, "FeatureExtractorConfig") - .def(py::init(), - py::arg("sampling_rate") = 16000, py::arg("feature_dim") = 80, - py::arg("low_freq") = 20.0f, py::arg("high_freq") = -400.0f, - py::arg("dither") = 0.0f) + .def(py::init(), + py::arg("sampling_rate") = 16000, + py::arg("feature_dim") = 80, + py::arg("low_freq") = 20.0f, + py::arg("high_freq") = -400.0f, + py::arg("dither") = 0.0f, + py::arg("normalize_samples") = true, + py::arg("snip_edges") = false) .def_readwrite("sampling_rate", &PyClass::sampling_rate) .def_readwrite("feature_dim", &PyClass::feature_dim) .def_readwrite("low_freq", &PyClass::low_freq) .def_readwrite("high_freq", &PyClass::high_freq) .def_readwrite("dither", &PyClass::dither) + .def_readwrite("normalize_samples", &PyClass::normalize_samples) + .def_readwrite("snip_edges", &PyClass::snip_edges) .def("__str__", &PyClass::ToString); } diff --git a/sherpa-onnx/python/csrc/online-stream.cc b/sherpa-onnx/python/csrc/online-stream.cc index 688a64f9..bb811385 100644 --- a/sherpa-onnx/python/csrc/online-stream.cc +++ b/sherpa-onnx/python/csrc/online-stream.cc @@ -22,6 +22,23 @@ Args: to the range [-1, 1]. )"; + +constexpr const char *kGetFramesUsage = R"( +Get n frames starting from the given frame index. +(hint: intended for debugging, for comparing FBANK features across pipelines) + +Args: + frame_index: + The starting frame index + n: + Number of frames to get. +Return: + Return a 2-D tensor of shape (n, feature_dim). + which is flattened into a 1-D vector (flattened in row major). 
+ Unflatten in python with: + `features = np.reshape(arr, (n, feature_dim))` +)"; + void PybindOnlineStream(py::module *m) { using PyClass = OnlineStream; py::class_(*m, "OnlineStream") @@ -34,6 +51,9 @@ void PybindOnlineStream(py::module *m) { py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage, py::call_guard()) .def("input_finished", &PyClass::InputFinished, + py::call_guard()) + .def("get_frames", &PyClass::GetFrames, + py::arg("frame_index"), py::arg("n"), kGetFramesUsage, py::call_guard()); } diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index 321f1cdf..77831de7 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -50,6 +50,8 @@ class OnlineRecognizer(object): low_freq: float = 20.0, high_freq: float = -400.0, dither: float = 0.0, + normalize_samples: bool = True, + snip_edges: bool = False, enable_endpoint_detection: bool = False, rule1_min_trailing_silence: float = 2.4, rule2_min_trailing_silence: float = 1.2, @@ -118,6 +120,15 @@ class OnlineRecognizer(object): By default the audio samples are in range [-1,+1], so dithering constant 0.00003 is a good value, equivalent to the default 1.0 from kaldi + normalize_samples: + True for +/- 1.0 range of audio samples (default, zipformer feats), + False for +/- 32k samples (ebranchformer features). + snip_edges: + handling of end of audio signal in kaldi feature extraction. + If true, end effects will be handled by outputting only frames that + completely fit in the file, and the number of frames depends on the + frame-length. If false, the number of frames depends only on the + frame-shift, and we reflect the data at the ends. enable_endpoint_detection: True to enable endpoint detection. False to disable endpoint detection. 
@@ -248,6 +259,8 @@ class OnlineRecognizer(object): feat_config = FeatureExtractorConfig( sampling_rate=sample_rate, + normalize_samples=normalize_samples, + snip_edges=snip_edges, feature_dim=feature_dim, low_freq=low_freq, high_freq=high_freq,