diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index aa50c4ab..382541f9 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -155,7 +155,9 @@ if(SHERPA_ONNX_ENABLE_RKNN) list(APPEND sources ./rknn/online-stream-rknn.cc ./rknn/online-transducer-greedy-search-decoder-rknn.cc + ./rknn/online-zipformer-ctc-model-rknn.cc ./rknn/online-zipformer-transducer-model-rknn.cc + ./rknn/utils.cc ) endif() diff --git a/sherpa-onnx/csrc/online-ctc-decoder.h b/sherpa-onnx/csrc/online-ctc-decoder.h index 65305e6a..4b01d678 100644 --- a/sherpa-onnx/csrc/online-ctc-decoder.h +++ b/sherpa-onnx/csrc/online-ctc-decoder.h @@ -43,12 +43,14 @@ class OnlineCtcDecoder { /** Run streaming CTC decoding given the output from the encoder model. * - * @param log_probs A 3-D tensor of shape (N, T, vocab_size) containing - * lob_probs. + * @param log_probs A 3-D tensor of shape + * (batch_size, num_frames, vocab_size) containing + * lob_probs in row major. * * @param results Input & Output parameters.. */ - virtual void Decode(Ort::Value log_probs, + virtual void Decode(const float *log_probs, int32_t batch_size, + int32_t num_frames, int32_t vocab_size, std::vector *results, OnlineStream **ss = nullptr, int32_t n = 0) = 0; diff --git a/sherpa-onnx/csrc/online-ctc-fst-decoder.cc b/sherpa-onnx/csrc/online-ctc-fst-decoder.cc index dea90918..817450c9 100644 --- a/sherpa-onnx/csrc/online-ctc-fst-decoder.cc +++ b/sherpa-onnx/csrc/online-ctc-fst-decoder.cc @@ -91,30 +91,23 @@ static void DecodeOne(const float *log_probs, int32_t num_rows, processed_frames += num_rows; } -void OnlineCtcFstDecoder::Decode(Ort::Value log_probs, +void OnlineCtcFstDecoder::Decode(const float *log_probs, int32_t batch_size, + int32_t num_frames, int32_t vocab_size, std::vector *results, OnlineStream **ss, int32_t n) { - std::vector log_probs_shape = - log_probs.GetTensorTypeAndShapeInfo().GetShape(); - - if (log_probs_shape[0] != results->size()) { + if (batch_size != results->size()) { SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", - static_cast(log_probs_shape[0]), - static_cast(results->size())); + batch_size, static_cast(results->size())); exit(-1); } - if (log_probs_shape[0] != n) { - SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d", - static_cast(log_probs_shape[0]), n); + if (batch_size != n) { + SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, n: %d", batch_size, + n); exit(-1); } - int32_t batch_size = static_cast(log_probs_shape[0]); - int32_t num_frames = static_cast(log_probs_shape[1]); - int32_t vocab_size = static_cast(log_probs_shape[2]); - - const float *p = log_probs.GetTensorData(); + const float *p = log_probs; for (int32_t i = 0; i != batch_size; ++i) { DecodeOne(p + i * num_frames * vocab_size, num_frames, vocab_size, diff --git a/sherpa-onnx/csrc/online-ctc-fst-decoder.h b/sherpa-onnx/csrc/online-ctc-fst-decoder.h index 992276d6..430bad93 100644 --- a/sherpa-onnx/csrc/online-ctc-fst-decoder.h +++ b/sherpa-onnx/csrc/online-ctc-fst-decoder.h @@ -19,8 +19,8 @@ class OnlineCtcFstDecoder : public OnlineCtcDecoder { OnlineCtcFstDecoder(const OnlineCtcFstDecoderConfig &config, int32_t blank_id); - void Decode(Ort::Value log_probs, - std::vector *results, + void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames, + int32_t vocab_size, std::vector *results, OnlineStream **ss = nullptr, int32_t n = 0) override; std::unique_ptr CreateFasterDecoder() diff --git a/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc index e813c987..0ef14826 100644 --- a/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.cc @@ -13,23 +13,16 @@ namespace sherpa_onnx { void OnlineCtcGreedySearchDecoder::Decode( - Ort::Value log_probs, std::vector *results, + const float *log_probs, int32_t batch_size, int32_t num_frames, + int32_t vocab_size, std::vector *results, OnlineStream ** /*ss=nullptr*/, int32_t /*n = 0*/) { - std::vector log_probs_shape = - log_probs.GetTensorTypeAndShapeInfo().GetShape(); - - if (log_probs_shape[0] != results->size()) { + if (batch_size != results->size()) { SHERPA_ONNX_LOGE("Size mismatch! log_probs.size(0) %d, results.size(0): %d", - static_cast(log_probs_shape[0]), - static_cast(results->size())); + batch_size, static_cast(results->size())); exit(-1); } - int32_t batch_size = static_cast(log_probs_shape[0]); - int32_t num_frames = static_cast(log_probs_shape[1]); - int32_t vocab_size = static_cast(log_probs_shape[2]); - - const float *p = log_probs.GetTensorData(); + const float *p = log_probs; for (int32_t b = 0; b != batch_size; ++b) { auto &r = (*results)[b]; diff --git a/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h index 0af37593..95f42541 100644 --- a/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h +++ b/sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h @@ -16,8 +16,8 @@ class OnlineCtcGreedySearchDecoder : public OnlineCtcDecoder { explicit OnlineCtcGreedySearchDecoder(int32_t blank_id) : blank_id_(blank_id) {} - void Decode(Ort::Value log_probs, - std::vector *results, + void Decode(const float *log_probs, int32_t batch_size, int32_t num_frames, + int32_t vocab_size, std::vector *results, OnlineStream **ss = nullptr, int32_t n = 0) override; private: diff --git a/sherpa-onnx/csrc/online-model-config.cc b/sherpa-onnx/csrc/online-model-config.cc index 10fb2669..bfac64ec 100644 --- a/sherpa-onnx/csrc/online-model-config.cc +++ b/sherpa-onnx/csrc/online-model-config.cc @@ -76,6 +76,15 @@ bool OnlineModelConfig::Validate() const { transducer.decoder.c_str(), transducer.joiner.c_str()); return false; } + + if (!zipformer2_ctc.model.empty() && + EndsWith(zipformer2_ctc.model, ".rknn")) { + SHERPA_ONNX_LOGE( + "--provider is %s, which is not rknn, but you pass rknn model " + "filename for zipformer2_ctc: '%s'", + provider_config.provider.c_str(), zipformer2_ctc.model.c_str()); + return false; + } } if (provider_config.provider == "rknn") { @@ -89,6 +98,15 @@ bool OnlineModelConfig::Validate() const { transducer.joiner.c_str()); return false; } + + if (!zipformer2_ctc.model.empty() && + EndsWith(zipformer2_ctc.model, ".onnx")) { + SHERPA_ONNX_LOGE( + "--provider rknn, but you pass onnx model filename for " + "zipformer2_ctc: '%s'", + zipformer2_ctc.model.c_str()); + return false; + } } if (!tokens_buf.empty() && FileExists(tokens)) { diff --git a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h index 797d90f0..32f6ac4d 100644 --- a/sherpa-onnx/csrc/online-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/online-recognizer-ctc-impl.h @@ -24,12 +24,11 @@ namespace sherpa_onnx { -static OnlineRecognizerResult Convert(const OnlineCtcDecoderResult &src, - const SymbolTable &sym_table, - float frame_shift_ms, - int32_t subsampling_factor, - int32_t segment, - int32_t frames_since_start) { +OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src, + const SymbolTable &sym_table, + float frame_shift_ms, + int32_t subsampling_factor, int32_t segment, + int32_t frames_since_start) { OnlineRecognizerResult r; r.tokens.reserve(src.tokens.size()); r.timestamps.reserve(src.tokens.size()); @@ -182,7 +181,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { std::vector> next_states = model_->UnStackStates(std::move(out_states)); - decoder_->Decode(std::move(out[0]), &results, ss, n); + std::vector log_probs_shape = + out[0].GetTensorTypeAndShapeInfo().GetShape(); + decoder_->Decode(out[0].GetTensorData(), log_probs_shape[0], + log_probs_shape[1], log_probs_shape[2], &results, ss, n); for (int32_t k = 0; k != n; ++k) { ss[k]->SetCtcResult(results[k]); @@ -196,8 +198,9 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { // TODO(fangjun): Remember to change these constants if needed int32_t frame_shift_ms = 10; int32_t subsampling_factor = 4; - auto r = Convert(decoder_result, sym_, frame_shift_ms, subsampling_factor, - s->GetCurrentSegment(), s->GetNumFramesSinceStart()); + auto r = + ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor, + s->GetCurrentSegment(), s->GetNumFramesSinceStart()); r.text = ApplyInverseTextNormalization(r.text); return r; } @@ -306,7 +309,10 @@ class OnlineRecognizerCtcImpl : public OnlineRecognizerImpl { std::vector results(1); results[0] = std::move(s->GetCtcResult()); - decoder_->Decode(std::move(out[0]), &results, &s, 1); + std::vector log_probs_shape = + out[0].GetTensorTypeAndShapeInfo().GetShape(); + decoder_->Decode(out[0].GetTensorData(), log_probs_shape[0], + log_probs_shape[1], log_probs_shape[2], &results, &s, 1); s->SetCtcResult(results[0]); } diff --git a/sherpa-onnx/csrc/online-recognizer-impl.cc b/sherpa-onnx/csrc/online-recognizer-impl.cc index 810e0a17..c8f4c269 100644 --- a/sherpa-onnx/csrc/online-recognizer-impl.cc +++ b/sherpa-onnx/csrc/online-recognizer-impl.cc @@ -27,6 +27,7 @@ #include "sherpa-onnx/csrc/text-utils.h" #if SHERPA_ONNX_ENABLE_RKNN +#include "sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h" #include "sherpa-onnx/csrc/rknn/online-recognizer-transducer-rknn-impl.h" #endif @@ -37,12 +38,15 @@ std::unique_ptr OnlineRecognizerImpl::Create( if (config.model_config.provider_config.provider == "rknn") { #if SHERPA_ONNX_ENABLE_RKNN // Currently, only zipformer v1 is suported for rknn - if (config.model_config.transducer.encoder.empty()) { + if (config.model_config.transducer.encoder.empty() && + config.model_config.zipformer2_ctc.model.empty()) { SHERPA_ONNX_LOGE( - "Only Zipformer transducers are currently supported by rknn. " - "Fallback to CPU"); - } else { + "Only Zipformer transducers and CTC models are currently supported " + "by rknn. Fallback to CPU"); + } else if (!config.model_config.transducer.encoder.empty()) { return std::make_unique(config); + } else if (!config.model_config.zipformer2_ctc.model.empty()) { + return std::make_unique(config); } #else SHERPA_ONNX_LOGE( diff --git a/sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h b/sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h new file mode 100644 index 00000000..9edd45a4 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h @@ -0,0 +1,204 @@ +// sherpa-onnx/csrc/rknn/online-recognizer-ctc-rknn-impl.h +// +// Copyright (c) 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_ +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/online-ctc-decoder.h" +#include "sherpa-onnx/csrc/online-ctc-fst-decoder.h" +#include "sherpa-onnx/csrc/online-ctc-greedy-search-decoder.h" +#include "sherpa-onnx/csrc/online-recognizer-impl.h" +#include "sherpa-onnx/csrc/rknn/online-stream-rknn.h" +#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h" +#include "sherpa-onnx/csrc/symbol-table.h" + +namespace sherpa_onnx { + +// defined in ../online-recognizer-ctc-impl.h +OnlineRecognizerResult ConvertCtc(const OnlineCtcDecoderResult &src, + const SymbolTable &sym_table, + float frame_shift_ms, + int32_t subsampling_factor, int32_t segment, + int32_t frames_since_start); + +class OnlineRecognizerCtcRknnImpl : public OnlineRecognizerImpl { + public: + explicit OnlineRecognizerCtcRknnImpl(const OnlineRecognizerConfig &config) + : OnlineRecognizerImpl(config), + config_(config), + model_( + std::make_unique(config.model_config)), + endpoint_(config_.endpoint_config) { + if (!config.model_config.tokens_buf.empty()) { + sym_ = SymbolTable(config.model_config.tokens_buf, false); + } else { + /// assuming tokens_buf and tokens are guaranteed not being both empty + sym_ = SymbolTable(config.model_config.tokens, true); + } + + InitDecoder(); + } + + template + explicit OnlineRecognizerCtcRknnImpl(Manager *mgr, + const OnlineRecognizerConfig &config) + : OnlineRecognizerImpl(mgr, config), + config_(config), + model_( + std::make_unique(config.model_config)), + sym_(mgr, config.model_config.tokens), + endpoint_(config_.endpoint_config) { + InitDecoder(); + } + + std::unique_ptr CreateStream() const override { + auto stream = std::make_unique(config_.feat_config); + stream->SetZipformerEncoderStates(model_->GetInitStates()); + stream->SetFasterDecoder(decoder_->CreateFasterDecoder()); + return stream; + } + + bool IsReady(OnlineStream *s) const override { + return s->GetNumProcessedFrames() + model_->ChunkSize() < + s->NumFramesReady(); + } + + void DecodeStreams(OnlineStream **ss, int32_t n) const override { + for (int32_t i = 0; i != n; ++i) { + DecodeStream(reinterpret_cast(ss[i])); + } + } + + OnlineRecognizerResult GetResult(OnlineStream *s) const override { + OnlineCtcDecoderResult decoder_result = s->GetCtcResult(); + + // TODO(fangjun): Remember to change these constants if needed + int32_t frame_shift_ms = 10; + int32_t subsampling_factor = 4; + auto r = + ConvertCtc(decoder_result, sym_, frame_shift_ms, subsampling_factor, + s->GetCurrentSegment(), s->GetNumFramesSinceStart()); + r.text = ApplyInverseTextNormalization(r.text); + return r; + } + + bool IsEndpoint(OnlineStream *s) const override { + if (!config_.enable_endpoint) { + return false; + } + + int32_t num_processed_frames = s->GetNumProcessedFrames(); + + // frame shift is 10 milliseconds + float frame_shift_in_seconds = 0.01; + + // subsampling factor is 4 + int32_t trailing_silence_frames = s->GetCtcResult().num_trailing_blanks * 4; + + return endpoint_.IsEndpoint(num_processed_frames, trailing_silence_frames, + frame_shift_in_seconds); + } + + void Reset(OnlineStream *s) const override { + // segment is incremented only when the last + // result is not empty + const auto &r = s->GetCtcResult(); + if (!r.tokens.empty()) { + s->GetCurrentSegment() += 1; + } + + // clear result + s->SetCtcResult({}); + + // clear states + reinterpret_cast(s)->SetZipformerEncoderStates( + model_->GetInitStates()); + + s->GetFasterDecoderProcessedFrames() = 0; + + // Note: We only update counters. The underlying audio samples + // are not discarded. + s->Reset(); + } + + private: + void InitDecoder() { + if (!sym_.Contains("") && !sym_.Contains("") && + !sym_.Contains("")) { + SHERPA_ONNX_LOGE( + "We expect that tokens.txt contains " + "the symbol or or and its ID."); + exit(-1); + } + + int32_t blank_id = 0; + if (sym_.Contains("")) { + blank_id = sym_[""]; + } else if (sym_.Contains("")) { + // for tdnn models of the yesno recipe from icefall + blank_id = sym_[""]; + } else if (sym_.Contains("")) { + // for WeNet CTC models + blank_id = sym_[""]; + } + + if (!config_.ctc_fst_decoder_config.graph.empty()) { + decoder_ = std::make_unique( + config_.ctc_fst_decoder_config, blank_id); + } else if (config_.decoding_method == "greedy_search") { + decoder_ = std::make_unique(blank_id); + } else { + SHERPA_ONNX_LOGE( + "Unsupported decoding method: %s for streaming CTC models", + config_.decoding_method.c_str()); + exit(-1); + } + } + + void DecodeStream(OnlineStreamRknn *s) const { + int32_t chunk_size = model_->ChunkSize(); + int32_t chunk_shift = model_->ChunkShift(); + + int32_t feat_dim = s->FeatureDim(); + + const auto num_processed_frames = s->GetNumProcessedFrames(); + std::vector features = + s->GetFrames(num_processed_frames, chunk_size); + s->GetNumProcessedFrames() += chunk_shift; + + auto &states = s->GetZipformerEncoderStates(); + auto p = model_->Run(features, std::move(states)); + states = std::move(p.second); + + std::vector results(1); + results[0] = std::move(s->GetCtcResult()); + + auto attr = model_->GetOutAttr(); + + decoder_->Decode(p.first.data(), attr.dims[0], attr.dims[1], attr.dims[2], + &results, reinterpret_cast(&s), 1); + s->SetCtcResult(results[0]); + } + + private: + OnlineRecognizerConfig config_; + std::unique_ptr model_; + std::unique_ptr decoder_; + SymbolTable sym_; + Endpoint endpoint_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_RECOGNIZER_CTC_RKNN_IMPL_H_ diff --git a/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc b/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc new file mode 100644 index 00000000..52a7b2ba --- /dev/null +++ b/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc @@ -0,0 +1,390 @@ +// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.cc +// +// Copyright (c) 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h" + +#include +#include +#include +#include +#include +#include + +#if __ANDROID_API__ >= 9 +#include "android/asset_manager.h" +#include "android/asset_manager_jni.h" +#endif + +#if __OHOS__ +#include "rawfile/raw_file_manager.h" +#endif + +#include "sherpa-onnx/csrc/file-utils.h" +#include "sherpa-onnx/csrc/rknn/macros.h" +#include "sherpa-onnx/csrc/rknn/utils.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +class OnlineZipformerCtcModelRknn::Impl { + public: + ~Impl() { + auto ret = rknn_destroy(ctx_); + if (ret != RKNN_SUCC) { + SHERPA_ONNX_LOGE("Failed to destroy the context"); + } + } + + explicit Impl(const OnlineModelConfig &config) : config_(config) { + { + auto buf = ReadFile(config.zipformer2_ctc.model); + Init(buf.data(), buf.size()); + } + + int32_t ret = RKNN_SUCC; + switch (config_.num_threads) { + case 1: + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_AUTO); + break; + case 0: + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0); + break; + case -1: + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_1); + break; + case -2: + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_2); + break; + case -3: + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1); + break; + case -4: + ret = rknn_set_core_mask(ctx_, RKNN_NPU_CORE_0_1_2); + break; + default: + SHERPA_ONNX_LOGE( + "Valid num_threads for rk npu is 1 (auto), 0 (core 0), -1 (core " + "1), -2 (core 2), -3 (core 0_1), -4 (core 0_1_2). Given: %d", + config_.num_threads); + break; + } + if (ret != RKNN_SUCC) { + SHERPA_ONNX_LOGE( + "Failed to select npu core to run the model (You can ignore it if " + "you " + "are not using RK3588."); + } + } + + // TODO(fangjun): Support Android + + std::vector> GetInitStates() const { + // input_attrs_[0] is for the feature + // input_attrs_[1:] is for states + // so we use -1 here + std::vector> states(input_attrs_.size() - 1); + + int32_t i = -1; + for (auto &attr : input_attrs_) { + i += 1; + if (i == 0) { + // skip processing the attr for features. + continue; + } + + if (attr.type == RKNN_TENSOR_FLOAT16) { + states[i - 1].resize(attr.n_elems * sizeof(float)); + } else if (attr.type == RKNN_TENSOR_INT64) { + states[i - 1].resize(attr.n_elems * sizeof(int64_t)); + } else { + SHERPA_ONNX_LOGE("Unsupported tensor type: %d, %s", attr.type, + get_type_string(attr.type)); + SHERPA_ONNX_EXIT(-1); + } + } + + return states; + } + + std::pair, std::vector>> Run( + std::vector features, + std::vector> states) const { + std::vector inputs(input_attrs_.size()); + + for (int32_t i = 0; i < static_cast(inputs.size()); ++i) { + auto &input = inputs[i]; + auto &attr = input_attrs_[i]; + input.index = attr.index; + + if (attr.type == RKNN_TENSOR_FLOAT16) { + input.type = RKNN_TENSOR_FLOAT32; + } else if (attr.type == RKNN_TENSOR_INT64) { + input.type = RKNN_TENSOR_INT64; + } else { + SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type, + get_type_string(attr.type)); + SHERPA_ONNX_EXIT(-1); + } + + input.fmt = attr.fmt; + if (i == 0) { + input.buf = reinterpret_cast(features.data()); + input.size = features.size() * sizeof(float); + } else { + input.buf = reinterpret_cast(states[i - 1].data()); + input.size = states[i - 1].size(); + } + } + + std::vector out(output_attrs_[0].n_elems); + + // Note(fangjun): We can reuse the memory from input argument `states` + // auto next_states = GetInitStates(); + auto &next_states = states; + + std::vector outputs(output_attrs_.size()); + for (int32_t i = 0; i < outputs.size(); ++i) { + auto &output = outputs[i]; + auto &attr = output_attrs_[i]; + output.index = attr.index; + output.is_prealloc = 1; + + if (attr.type == RKNN_TENSOR_FLOAT16) { + output.want_float = 1; + } else if (attr.type == RKNN_TENSOR_INT64) { + output.want_float = 0; + } else { + SHERPA_ONNX_LOGE("Unsupported tensor type %d, %s", attr.type, + get_type_string(attr.type)); + SHERPA_ONNX_EXIT(-1); + } + + if (i == 0) { + output.size = out.size() * sizeof(float); + output.buf = reinterpret_cast(out.data()); + } else { + output.size = next_states[i - 1].size(); + output.buf = reinterpret_cast(next_states[i - 1].data()); + } + } + + auto ret = rknn_inputs_set(ctx_, inputs.size(), inputs.data()); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to set inputs"); + + ret = rknn_run(ctx_, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to run the model"); + + ret = rknn_outputs_get(ctx_, outputs.size(), outputs.data(), nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get model output"); + + for (int32_t i = 0; i < next_states.size(); ++i) { + const auto &attr = input_attrs_[i + 1]; + if (attr.n_dims == 4) { + // TODO(fangjun): The transpose is copied from + // https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22 + // I don't understand why we need to do that. + std::vector dst(next_states[i].size()); + int32_t n = attr.dims[0]; + int32_t h = attr.dims[1]; + int32_t w = attr.dims[2]; + int32_t c = attr.dims[3]; + ConvertNCHWtoNHWC( + reinterpret_cast(next_states[i].data()), n, c, h, w, + reinterpret_cast(dst.data())); + next_states[i] = std::move(dst); + } + } + + return {std::move(out), std::move(next_states)}; + } + + int32_t ChunkSize() const { return T_; } + + int32_t ChunkShift() const { return decode_chunk_len_; } + + int32_t VocabSize() const { return vocab_size_; } + + rknn_tensor_attr GetOutAttr() const { return output_attrs_[0]; } + + private: + void Init(void *model_data, size_t model_data_length) { + auto ret = rknn_init(&ctx_, model_data, model_data_length, 0, nullptr); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to init model '%s'", + config_.zipformer2_ctc.model.c_str()); + + if (config_.debug) { + rknn_sdk_version v; + ret = rknn_query(ctx_, RKNN_QUERY_SDK_VERSION, &v, sizeof(v)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get rknn sdk version"); + + SHERPA_ONNX_LOGE("sdk api version: %s, driver version: %s", v.api_version, + v.drv_version); + } + + rknn_input_output_num io_num; + ret = rknn_query(ctx_, RKNN_QUERY_IN_OUT_NUM, &io_num, sizeof(io_num)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get I/O information for the model"); + + if (config_.debug) { + SHERPA_ONNX_LOGE("model: %d inputs, %d outputs", + static_cast(io_num.n_input), + static_cast(io_num.n_output)); + } + + input_attrs_.resize(io_num.n_input); + output_attrs_.resize(io_num.n_output); + + int32_t i = 0; + for (auto &attr : input_attrs_) { + memset(&attr, 0, sizeof(attr)); + attr.index = i; + ret = rknn_query(ctx_, RKNN_QUERY_INPUT_ATTR, &attr, sizeof(attr)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model input %d", i); + i += 1; + } + + if (config_.debug) { + std::ostringstream os; + std::string sep; + for (auto &attr : input_attrs_) { + os << sep << ToString(attr); + sep = "\n"; + } + SHERPA_ONNX_LOGE("\n----------Model inputs info----------\n%s", + os.str().c_str()); + } + + i = 0; + for (auto &attr : output_attrs_) { + memset(&attr, 0, sizeof(attr)); + attr.index = i; + ret = rknn_query(ctx_, RKNN_QUERY_OUTPUT_ATTR, &attr, sizeof(attr)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to get attr for model output %d", i); + i += 1; + } + + if (config_.debug) { + std::ostringstream os; + std::string sep; + for (auto &attr : output_attrs_) { + os << sep << ToString(attr); + sep = "\n"; + } + SHERPA_ONNX_LOGE("\n----------Model outputs info----------\n%s", + os.str().c_str()); + } + + rknn_custom_string custom_string; + ret = rknn_query(ctx_, RKNN_QUERY_CUSTOM_STRING, &custom_string, + sizeof(custom_string)); + SHERPA_ONNX_RKNN_CHECK(ret, "Failed to read custom string from the model"); + if (config_.debug) { + SHERPA_ONNX_LOGE("customs string: %s", custom_string.string); + } + auto meta = Parse(custom_string); + + if (config_.debug) { + for (const auto &p : meta) { + SHERPA_ONNX_LOGE("%s: %s", p.first.c_str(), p.second.c_str()); + } + } + + if (meta.count("T")) { + T_ = atoi(meta.at("T").c_str()); + } + + if (meta.count("decode_chunk_len")) { + decode_chunk_len_ = atoi(meta.at("decode_chunk_len").c_str()); + } + + vocab_size_ = output_attrs_[0].dims[2]; + + if (config_.debug) { +#if __OHOS__ + SHERPA_ONNX_LOGE("T: %{public}d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_); + SHERPA_ONNX_LOGE("vocab_size: %{public}d", vocab_size); +#else + SHERPA_ONNX_LOGE("T: %d", T_); + SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); + SHERPA_ONNX_LOGE("vocab_size: %d", vocab_size_); +#endif + } + + if (T_ == 0) { + SHERPA_ONNX_LOGE( + "Invalid T. Please use the script from icefall to export your model"); + SHERPA_ONNX_EXIT(-1); + } + + if (decode_chunk_len_ == 0) { + SHERPA_ONNX_LOGE( + "Invalid decode_chunk_len. Please use the script from icefall to " + "export your model"); + SHERPA_ONNX_EXIT(-1); + } + } + + private: + OnlineModelConfig config_; + rknn_context ctx_ = 0; + + std::vector input_attrs_; + std::vector output_attrs_; + + int32_t T_ = 0; + int32_t decode_chunk_len_ = 0; + int32_t vocab_size_ = 0; +}; + +OnlineZipformerCtcModelRknn::~OnlineZipformerCtcModelRknn() = default; + +OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn( + const OnlineModelConfig &config) + : impl_(std::make_unique(config)) {} + +template +OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn( + Manager *mgr, const OnlineModelConfig &config) + : impl_(std::make_unique(mgr, config)) {} + +std::vector> OnlineZipformerCtcModelRknn::GetInitStates() + const { + return impl_->GetInitStates(); +} + +std::pair, std::vector>> +OnlineZipformerCtcModelRknn::Run( + std::vector features, + std::vector> states) const { + return impl_->Run(std::move(features), std::move(states)); +} + +int32_t OnlineZipformerCtcModelRknn::ChunkSize() const { + return impl_->ChunkSize(); +} + +int32_t OnlineZipformerCtcModelRknn::ChunkShift() const { + return impl_->ChunkShift(); +} + +int32_t OnlineZipformerCtcModelRknn::VocabSize() const { + return impl_->VocabSize(); +} + +rknn_tensor_attr OnlineZipformerCtcModelRknn::GetOutAttr() const { + return impl_->GetOutAttr(); +} + +#if __ANDROID_API__ >= 9 +template OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn( + AAssetManager *mgr, const OnlineModelConfig &config); +#endif + +#if __OHOS__ +template OnlineZipformerCtcModelRknn::OnlineZipformerCtcModelRknn( + NativeResourceManager *mgr, const OnlineModelConfig &config); +#endif + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h b/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h new file mode 100644 index 00000000..49242a4e --- /dev/null +++ b/sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h @@ -0,0 +1,46 @@ +// sherpa-onnx/csrc/rknn/online-zipformer-ctc-model-rknn.h +// +// Copyright (c) 2025 Xiaomi Corporation +#ifndef SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_ +#define SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_ + +#include +#include +#include + +#include "rknn_api.h" // NOLINT +#include "sherpa-onnx/csrc/online-model-config.h" + +namespace sherpa_onnx { + +class OnlineZipformerCtcModelRknn { + public: + ~OnlineZipformerCtcModelRknn(); + + explicit OnlineZipformerCtcModelRknn(const OnlineModelConfig &config); + + template + OnlineZipformerCtcModelRknn(Manager *mgr, const OnlineModelConfig &config); + + std::vector> GetInitStates() const; + + std::pair, std::vector>> Run( + std::vector features, + std::vector> states) const; + + int32_t ChunkSize() const; + + int32_t ChunkShift() const; + + int32_t VocabSize() const; + + rknn_tensor_attr GetOutAttr() const; + + private: + class Impl; + std::unique_ptr impl_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_ONLINE_ZIPFORMER_CTC_MODEL_RKNN_H_ diff --git a/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc index 67a81fdf..7b6d505d 100644 --- a/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc +++ b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc @@ -1,6 +1,6 @@ // sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.cc // -// Copyright (c) 2023 Xiaomi Corporation +// Copyright (c) 2025 Xiaomi Corporation #include "sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h" @@ -22,68 +22,11 @@ #include "sherpa-onnx/csrc/file-utils.h" #include "sherpa-onnx/csrc/rknn/macros.h" +#include "sherpa-onnx/csrc/rknn/utils.h" #include "sherpa-onnx/csrc/text-utils.h" namespace sherpa_onnx { -// chw -> hwc -static void Transpose(const float *src, int32_t n, int32_t channel, - int32_t height, int32_t width, float *dst) { - for (int32_t i = 0; i < n; ++i) { - for (int32_t h = 0; h < height; ++h) { - for (int32_t w = 0; w < width; ++w) { - for (int32_t c = 0; c < channel; ++c) { - // dst[h, w, c] = src[c, h, w] - dst[i * height * width * channel + h * width * channel + w * channel + - c] = src[i * height * width * channel + c * height * width + - h * width + w]; - } - } - } - } -} - -static std::string ToString(const rknn_tensor_attr &attr) { - std::ostringstream os; - os << "{"; - os << attr.index; - os << ", name: " << attr.name; - os << ", shape: ("; - std::string sep; - for (int32_t i = 0; i < static_cast(attr.n_dims); ++i) { - os << sep << attr.dims[i]; - sep = ","; - } - os << ")"; - os << ", n_elems: " << attr.n_elems; - os << ", size: " << attr.size; - os << ", fmt: " << get_format_string(attr.fmt); - os << ", type: " << get_type_string(attr.type); - os << ", pass_through: " << (attr.pass_through ? "true" : "false"); - os << "}"; - return os.str(); -} - -static std::unordered_map Parse( - const rknn_custom_string &custom_string) { - std::unordered_map ans; - std::vector fields; - SplitStringToVector(custom_string.string, ";", false, &fields); - - std::vector tmp; - for (const auto &f : fields) { - SplitStringToVector(f, "=", false, &tmp); - if (tmp.size() != 2) { - SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string, - f.c_str()); - SHERPA_ONNX_EXIT(-1); - } - ans[std::move(tmp[0])] = std::move(tmp[1]); - } - - return ans; -} - class OnlineZipformerTransducerModelRknn::Impl { public: ~Impl() { @@ -285,7 +228,7 @@ class OnlineZipformerTransducerModelRknn::Impl { for (int32_t i = 0; i < next_states.size(); ++i) { const auto &attr = encoder_input_attrs_[i + 1]; if (attr.n_dims == 4) { - // TODO(fangjun): The transpose is copied from + // TODO(fangjun): The ConvertNCHWtoNHWC is copied from // https://github.com/airockchip/rknn_model_zoo/blob/main/examples/zipformer/cpp/process.cc#L22 // I don't understand why we need to do that. std::vector dst(next_states[i].size()); @@ -293,8 +236,9 @@ class OnlineZipformerTransducerModelRknn::Impl { int32_t h = attr.dims[1]; int32_t w = attr.dims[2]; int32_t c = attr.dims[3]; - Transpose(reinterpret_cast(next_states[i].data()), n, c, - h, w, reinterpret_cast(dst.data())); + ConvertNCHWtoNHWC( + reinterpret_cast(next_states[i].data()), n, c, h, w, + reinterpret_cast(dst.data())); next_states[i] = std::move(dst); } } @@ -527,11 +471,9 @@ class OnlineZipformerTransducerModelRknn::Impl { #if __OHOS__ SHERPA_ONNX_LOGE("T: %{public}d", T_); SHERPA_ONNX_LOGE("decode_chunk_len_: %{public}d", decode_chunk_len_); - SHERPA_ONNX_LOGE("context_size: %{public}d", context_size_); #else SHERPA_ONNX_LOGE("T: %d", T_); SHERPA_ONNX_LOGE("decode_chunk_len_: %d", decode_chunk_len_); - SHERPA_ONNX_LOGE("context_size: %d", context_size_); #endif } } @@ -597,6 +539,11 @@ class OnlineZipformerTransducerModelRknn::Impl { SHERPA_ONNX_EXIT(-1); } + context_size_ = decoder_input_attrs_[0].dims[1]; + if (config_.debug) { + SHERPA_ONNX_LOGE("context_size: %d", context_size_); + } + i = 0; for (auto &attr : decoder_output_attrs_) { memset(&attr, 0, sizeof(attr)); diff --git a/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h index bc821afa..e99aaf98 100644 --- a/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h +++ b/sherpa-onnx/csrc/rknn/online-zipformer-transducer-model-rknn.h @@ -14,8 +14,11 @@ namespace sherpa_onnx { -// this is for zipformer v1, i.e., the folder -// pruned_transducer_statelss7_streaming from icefall +// this is for zipformer v1 and v2, i.e., the folder +// pruned_transducer_statelss7_streaming +// and +// zipformer +// from icefall class OnlineZipformerTransducerModelRknn { public: ~OnlineZipformerTransducerModelRknn(); diff --git a/sherpa-onnx/csrc/rknn/utils.cc b/sherpa-onnx/csrc/rknn/utils.cc new file mode 100644 index 00000000..e1581ed4 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/utils.cc @@ -0,0 +1,73 @@ +// sherpa-onnx/csrc/utils.cc +// +// Copyright 2025 Xiaomi Corporation + +#include "sherpa-onnx/csrc/rknn/utils.h" + +#include +#include +#include + +#include "sherpa-onnx/csrc/macros.h" +#include "sherpa-onnx/csrc/text-utils.h" + +namespace sherpa_onnx { + +void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel, + int32_t height, int32_t width, float *dst) { + for (int32_t i = 0; i < n; ++i) { + for (int32_t h = 0; h < height; ++h) { + for (int32_t w = 0; w < width; ++w) { + for (int32_t c = 0; c < channel; ++c) { + // dst[h, w, c] = src[c, h, w] + dst[i * height * width * channel + h * width * channel + w * channel + + c] = src[i * height * width * channel + c * height * width + + h * width + w]; + } + } + } + } +} + +std::string ToString(const rknn_tensor_attr &attr) { + std::ostringstream os; + os << "{"; + os << attr.index; + os << ", name: " << attr.name; + os << ", shape: ("; + std::string sep; + for (int32_t i = 0; i < static_cast(attr.n_dims); ++i) { + os << sep << attr.dims[i]; + sep = ","; + } + os << ")"; + os << ", n_elems: " << attr.n_elems; + os << ", size: " << attr.size; + os << ", fmt: " << get_format_string(attr.fmt); + os << ", type: " << get_type_string(attr.type); + os << ", pass_through: " << (attr.pass_through ? "true" : "false"); + os << "}"; + return os.str(); +} + +std::unordered_map Parse( + const rknn_custom_string &custom_string) { + std::unordered_map ans; + std::vector fields; + SplitStringToVector(custom_string.string, ";", false, &fields); + + std::vector tmp; + for (const auto &f : fields) { + SplitStringToVector(f, "=", false, &tmp); + if (tmp.size() != 2) { + SHERPA_ONNX_LOGE("Invalid custom string %s for %s", custom_string.string, + f.c_str()); + SHERPA_ONNX_EXIT(-1); + } + ans[std::move(tmp[0])] = std::move(tmp[1]); + } + + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/rknn/utils.h b/sherpa-onnx/csrc/rknn/utils.h new file mode 100644 index 00000000..077d3f65 --- /dev/null +++ b/sherpa-onnx/csrc/rknn/utils.h @@ -0,0 +1,23 @@ +// sherpa-onnx/csrc/utils.h +// +// Copyright 2025 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_RKNN_UTILS_H_ +#define SHERPA_ONNX_CSRC_RKNN_UTILS_H_ + +#include +#include + +#include "rknn_api.h" // NOLINT + +namespace sherpa_onnx { +void ConvertNCHWtoNHWC(const float *src, int32_t n, int32_t channel, + int32_t height, int32_t width, float *dst); + +std::string ToString(const rknn_tensor_attr &attr); + +std::unordered_map Parse( + const rknn_custom_string &custom_string); +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_RKNN_UTILS_H_ diff --git a/sherpa-onnx/csrc/sherpa-onnx.cc b/sherpa-onnx/csrc/sherpa-onnx.cc index 7f0db215..3af4f412 100644 --- a/sherpa-onnx/csrc/sherpa-onnx.cc +++ b/sherpa-onnx/csrc/sherpa-onnx.cc @@ -83,6 +83,7 @@ for a list of pre-trained models to download. po.Read(argc, argv); if (po.NumArgs() < 1) { po.PrintUsage(); + fprintf(stderr, "Error! Please provide at lease 1 wav file\n"); exit(EXIT_FAILURE); }